From e9755d87a641b02daf82604ef23ce09647158cd4 Mon Sep 17 00:00:00 2001
From: Kai Yang
Date: Sun, 13 Mar 2022 17:05:44 +0800
Subject: [PATCH] [Lint] One parameter/argument per line for C++ code (#22725)

Resolving parameter/argument conflicts is really annoying. It is even more
frustrating when we merge code from the community into Ant's internal code
base and face hundreds of conflicts caused by parameters/arguments.

In this PR, I updated the clang-format style so that parameters/arguments
are placed on separate lines when they cannot all fit into a single line
(see the before/after illustration below). There are several benefits:

* Conflict resolution is easier.
* Fewer potential human mistakes when resolving conflicts.
* Git history and Git blame are more straightforward.
* Better readability.
* Alignment with the new Python format style.
---
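For illustration, the before/after pair below is adapted from the
local_mode_task_submitter.cc hunk later in this patch; it is an editorial
aid, not part of the diff. BinPackArguments controls call sites and
BinPackParameters controls parameter lists in declarations; code that fits
within the 90-column limit stays on a single line either way.

    // Before (BinPackArguments: true): wrapped arguments are bin-packed,
    // filling each line up to the 90-column limit.
    builder.SetActorTaskSpec(invocation.actor_id, actor_creation_dummy_object_id,
                             ObjectID(), invocation.actor_counter);

    // After (BinPackArguments: false): once the call cannot fit on one
    // line, every argument is placed on its own line.
    builder.SetActorTaskSpec(invocation.actor_id,
                             actor_creation_dummy_object_id,
                             ObjectID(),
                             invocation.actor_counter);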
 .clang-format                                 |   2 +
 .../extract_compile_command.cc                |   7 +-
 cpp/include/ray/api.h                         |  13 +-
 cpp/include/ray/api/actor_creator.h           |   6 +-
 cpp/include/ray/api/actor_handle.h            |   4 +-
 cpp/include/ray/api/actor_task_caller.h       |   9 +-
 cpp/include/ray/api/arguments.h               |  12 +-
 cpp/include/ray/api/function_manager.h        |  33 +-
 cpp/include/ray/api/logging.h                 |   3 +-
 cpp/include/ray/api/ray_config.h              |   2 +
 cpp/include/ray/api/ray_remote.h              |   3 +-
 cpp/include/ray/api/ray_runtime.h             |   9 +-
 cpp/include/ray/api/task_caller.h             |   6 +-
 cpp/include/ray/api/task_options.h            |   3 +-
 cpp/include/ray/api/wait_result.h             |   4 +-
 cpp/src/ray/config_internal.cc                |  32 +-
 cpp/src/ray/runtime/abstract_ray_runtime.cc   |  26 +-
 cpp/src/ray/runtime/abstract_ray_runtime.h    |   3 +-
 cpp/src/ray/runtime/local_mode_ray_runtime.cc |   3 +-
 cpp/src/ray/runtime/logging.cc                |   3 +-
 .../runtime/object/local_mode_object_store.cc |  20 +-
 .../runtime/object/local_mode_object_store.h  |   3 +-
 .../ray/runtime/object/native_object_store.cc |   9 +-
 .../ray/runtime/object/native_object_store.h  |   3 +-
 cpp/src/ray/runtime/object/object_store.h     |   4 +-
 .../runtime/task/local_mode_task_submitter.cc |  58 +-
 .../ray/runtime/task/native_task_submitter.cc |  22 +-
 cpp/src/ray/runtime/task/task_executor.cc     |  25 +-
 cpp/src/ray/runtime/task/task_executor.h      |   9 +-
 cpp/src/ray/test/api_test.cc                  |  14 +-
 cpp/src/ray/test/cluster/cluster_mode_test.cc |   8 +-
 cpp/src/ray/test/cluster/counter.cc           |  17 +-
 cpp/src/ray/test/cluster/plus.cc              |  14 +-
 cpp/src/ray/test/examples/simple_kv_store.cc  |   4 +-
 cpp/src/ray/util/process_helper.cc            |  25 +-
 cpp/src/ray/util/process_helper.h             |   3 +-
 cpp/src/ray/util/util.cc                      |   6 +-
 cpp/src/ray/worker/default_worker.cc          |   4 +-
 src/mock/ray/core_worker/actor_creator.h      |  26 +-
 src/mock/ray/core_worker/core_worker.h        |  87 ++-
 src/mock/ray/core_worker/lease_policy.h       |  10 +-
 src/mock/ray/core_worker/task_manager.h       |  45 +-
 .../transport/direct_actor_transport.h        |  37 +-
 src/mock/ray/gcs/gcs_client/accessor.h        | 219 ++--
 .../ray/gcs/gcs_server/gcs_actor_manager.h    |  30 +-
 .../ray/gcs/gcs_server/gcs_actor_scheduler.h  |  56 +-
 .../gcs/gcs_server/gcs_heartbeat_manager.h    |   9 +-
 src/mock/ray/gcs/gcs_server/gcs_job_manager.h |  33 +-
 src/mock/ray/gcs/gcs_server/gcs_kv_manager.h  |  41 +-
 .../ray/gcs/gcs_server/gcs_node_manager.h     |  21 +-
 .../gcs_server/gcs_placement_group_manager.h  |  18 +-
 .../gcs_placement_group_scheduler.h           |  62 +-
 .../ray/gcs/gcs_server/gcs_resource_manager.h |  15 +-
 .../gcs/gcs_server/gcs_resource_scheduler.h   |   6 +-
 .../ray/gcs/gcs_server/gcs_table_storage.h    |  26 +-
 .../ray/gcs/gcs_server/gcs_worker_manager.h   |  18 +-
 .../ray/gcs/gcs_server/stats_handler_impl.h   |   9 +-
 src/mock/ray/gcs/pubsub/gcs_pub_sub.h         |   7 +-
 .../gcs/store_client/in_memory_store_client.h |  70 +-
 .../ray/gcs/store_client/redis_store_client.h |  70 +-
 src/mock/ray/gcs/store_client/store_client.h  |  70 +-
 src/mock/ray/pubsub/publisher.h               |  42 +-
 src/mock/ray/pubsub/subscriber.h              |  38 +-
 src/mock/ray/raylet/agent_manager.h           |  21 +-
 src/mock/ray/raylet/dependency_manager.h      |   3 +-
 src/mock/ray/raylet/node_manager.h            |  66 +-
 .../raylet/scheduling/cluster_task_manager.h  |  46 +-
 .../cluster_task_manager_interface.h          |  45 +-
 src/mock/ray/raylet/worker.h                  |  22 +-
 src/mock/ray/raylet/worker_pool.h             |  51 +-
 src/mock/ray/raylet_client/raylet_client.h    | 107 ++-
 src/mock/ray/rpc/worker/core_worker_client.h  |  66 +-
 src/ray/common/asio/asio_util.h               |   3 +-
 src/ray/common/asio/io_service_pool.h         |   1 +
 src/ray/common/asio/periodical_runner.cc      |  21 +-
 src/ray/common/asio/periodical_runner.h       |   3 +-
 src/ray/common/buffer.h                       |   3 +-
 src/ray/common/bundle_spec.cc                 |   3 +-
 src/ray/common/bundle_spec.h                  |   3 +-
 src/ray/common/client_connection.cc           |  72 +-
 src/ray/common/client_connection.h            |  29 +-
 src/ray/common/event_stats.cc                 |  10 +-
 src/ray/common/event_stats.h                  |   3 +-
 src/ray/common/function_descriptor.cc         |   6 +-
 src/ray/common/grpc_util.h                    |  14 +-
 src/ray/common/id.cc                          |  39 +-
 src/ray/common/id.h                           |  12 +-
 src/ray/common/network_util.cc                |  18 +-
 src/ray/common/network_util.h                 |  13 +-
 src/ray/common/placement_group.h              |   9 +-
 src/ray/common/ray_config_def.h               |   6 +-
 src/ray/common/ray_object.cc                  |   3 +-
 src/ray/common/ray_object.h                   |   9 +-
 src/ray/common/task/task_spec.cc              |   6 +-
 src/ray/common/task/task_spec.h               |   6 +-
 src/ray/common/task/task_util.h               |  45 +-
 src/ray/common/test/client_connection_test.cc |  74 +-
 src/ray/common/test_util.cc                   |  26 +-
 src/ray/common/test_util.h                    |   6 +-
 src/ray/core_worker/actor_handle.cc           |  52 +-
 src/ray/core_worker/actor_handle.h            |  18 +-
 src/ray/core_worker/actor_manager.cc          |  69 +-
 src/ray/core_worker/actor_manager.h           |  24 +-
 src/ray/core_worker/common.h                  |  28 +-
 src/ray/core_worker/context.cc                |   3 +-
 src/ray/core_worker/core_worker.cc            | 683 ++++++++++++------
 src/ray/core_worker/core_worker.h             | 131 ++--
 src/ray/core_worker/core_worker_options.h     |   7 +-
 src/ray/core_worker/core_worker_process.cc    |   8 +-
 src/ray/core_worker/future_resolver.cc        |   8 +-
 src/ray/core_worker/future_resolver.h         |   3 +-
 .../core_worker/gcs_server_address_updater.cc |   7 +-
 .../core_worker/gcs_server_address_updater.h  |   3 +-
 src/ray/core_worker/lease_policy.h            |   3 +-
 .../java/io_ray_runtime_RayNativeRuntime.cc   |  84 ++-
 .../java/io_ray_runtime_RayNativeRuntime.h    |  24 +-
 .../io_ray_runtime_actor_NativeActorHandle.cc |   7 +-
 .../io_ray_runtime_actor_NativeActorHandle.h  |   3 +-
 .../io_ray_runtime_gcs_GlobalStateAccessor.cc |  18 +-
 .../io_ray_runtime_gcs_GlobalStateAccessor.h  |  23 +-
 .../io_ray_runtime_metric_NativeMetric.cc     | 111 ++-
 ...io_ray_runtime_object_NativeObjectStore.cc |  63 +-
 .../io_ray_runtime_object_NativeObjectStore.h |  12 +-
 ...io_ray_runtime_task_NativeTaskSubmitter.cc |  71 +-
 .../io_ray_runtime_task_NativeTaskSubmitter.h |  16 +-
 src/ray/core_worker/lib/java/jni_init.cc      |  29 +-
 src/ray/core_worker/lib/java/jni_utils.h      |  73 +-
 .../core_worker/object_recovery_manager.cc    |  20 +-
 src/ray/core_worker/object_recovery_manager.h |  24 +-
 src/ray/core_worker/profiling.cc              |   6 +-
 src/ray/core_worker/profiling.h               |   3 +-
 src/ray/core_worker/reference_count.cc        | 101 ++-
 src/ray/core_worker/reference_count.h         |  65 +-
 src/ray/core_worker/reference_count_test.cc   | 355 +++++----
 .../memory_store/memory_store.cc              |  68 +-
 .../memory_store/memory_store.h               |  23 +-
 .../store_provider/plasma_store_provider.cc   | 101 ++-
 .../store_provider/plasma_store_provider.h    |  40 +-
 src/ray/core_worker/task_manager.cc           |  37 +-
 src/ray/core_worker/task_manager.h            |  41 +-
 .../core_worker/test/actor_creator_test.cc    |   3 +-
 .../core_worker/test/actor_manager_test.cc    | 108 ++-
 src/ray/core_worker/test/core_worker_test.cc  | 147 ++--
 .../test/direct_actor_transport_mock_test.cc  |   7 +-
 .../test/direct_actor_transport_test.cc       |  53 +-
 .../test/direct_task_transport_mock_test.cc   |  10 +-
 .../test/direct_task_transport_test.cc        | 449 +++++++---
 src/ray/core_worker/test/memory_store_test.cc |  10 +-
 src/ray/core_worker/test/mock_worker.cc       |  17 +-
 .../test/object_recovery_manager_test.cc      |  81 ++-
 .../core_worker/test/scheduling_queue_test.cc |  15 +-
 src/ray/core_worker/test/task_manager_test.cc |  13 +-
 .../transport/actor_scheduling_queue.cc       |  23 +-
 .../transport/actor_scheduling_queue.h        |   9 +-
 .../transport/actor_scheduling_util.cc        |   6 +-
 .../transport/actor_scheduling_util.h         |   6 +-
 .../transport/dependency_resolver.cc          |  54 +-
 .../transport/direct_actor_task_submitter.cc  |  24 +-
 .../transport/direct_actor_task_submitter.h   |  26 +-
 .../transport/direct_actor_transport.cc       |  74 +-
 .../transport/direct_actor_transport.h        |   6 +-
 .../transport/direct_task_transport.cc        | 107 ++-
 .../transport/direct_task_transport.h         |  33 +-
 .../transport/normal_scheduling_queue.cc      |  23 +-
 .../transport/normal_scheduling_queue.h       |   3 +-
 .../out_of_order_actor_scheduling_queue.cc    |  26 +-
 .../out_of_order_actor_scheduling_queue.h     |   9 +-
 .../core_worker/transport/scheduling_queue.h  |   3 +-
 src/ray/gcs/asio.cc                           |  14 +-
 src/ray/gcs/gcs_client/accessor.cc            | 132 ++--
 src/ray/gcs/gcs_client/accessor.h             |  50 +-
 src/ray/gcs/gcs_client/gcs_client.cc          |  19 +-
 src/ray/gcs/gcs_client/gcs_client.h           |   9 +-
 .../gcs/gcs_client/global_state_accessor.cc   |  17 +-
 .../gcs/gcs_client/global_state_accessor.h    |   4 +-
 .../gcs/gcs_client/test/gcs_client_test.cc    |  20 +-
 .../test/global_state_accessor_test.cc        |  10 +-
 .../gcs/gcs_server/gcs_actor_distribution.cc  |  30 +-
 .../gcs/gcs_server/gcs_actor_distribution.h   |   9 +-
 src/ray/gcs/gcs_server/gcs_actor_manager.cc   | 147 ++--
 src/ray/gcs/gcs_server/gcs_actor_manager.h    |  15 +-
 src/ray/gcs/gcs_server/gcs_actor_scheduler.cc |  28 +-
 src/ray/gcs/gcs_server/gcs_actor_scheduler.h  |  19 +-
 .../gcs/gcs_server/gcs_heartbeat_manager.cc   |   8 +-
 src/ray/gcs/gcs_server/gcs_job_manager.cc     |   9 +-
 src/ray/gcs/gcs_server/gcs_job_manager.h      |   3 +-
 src/ray/gcs/gcs_server/gcs_kv_manager.cc      |  76 +-
 src/ray/gcs/gcs_server/gcs_kv_manager.h       |  59 +-
 src/ray/gcs/gcs_server/gcs_node_manager.cc    |  63 +-
 src/ray/gcs/gcs_server/gcs_node_manager.h     |   3 +-
 .../gcs_server/gcs_placement_group_manager.cc | 101 +--
 .../gcs_server/gcs_placement_group_manager.h  |   3 +-
 .../gcs_placement_group_scheduler.cc          |  81 ++-
 .../gcs_placement_group_scheduler.h           |  22 +-
 .../gcs_server/gcs_redis_failure_detector.cc  |   3 +-
 .../gcs/gcs_server/gcs_resource_manager.cc    |  18 +-
 src/ray/gcs/gcs_server/gcs_resource_manager.h |   3 +-
 .../gcs_server/gcs_resource_report_poller.cc  |  16 +-
 .../gcs_server/gcs_resource_report_poller.h   |  13 +-
 .../gcs/gcs_server/gcs_resource_scheduler.cc  |   6 +-
 src/ray/gcs/gcs_server/gcs_server.cc          |  87 ++-
 src/ray/gcs/gcs_server/gcs_server_main.cc     |   9 +-
 src/ray/gcs/gcs_server/gcs_table_storage.cc   |  32 +-
 src/ray/gcs/gcs_server/gcs_worker_manager.cc  |  46 +-
 .../grpc_based_resource_broadcaster.cc        |   3 +-
 .../grpc_based_resource_broadcaster.h         |   6 +-
 src/ray/gcs/gcs_server/pubsub_handler.cc      |  24 +-
 src/ray/gcs/gcs_server/ray_syncer.h           |   3 +-
 src/ray/gcs/gcs_server/stats_handler_impl.cc  |   7 +-
 .../gcs_server/test/gcs_actor_manager_test.cc | 192 +++-
 .../test/gcs_actor_scheduler_mock_test.cc     |   7 +-
 .../test/gcs_based_actor_scheduler_test.cc    |  90 ++-
 .../gcs_server/test/gcs_job_manager_test.cc   |  16 +-
 .../gcs_server/test/gcs_kv_manager_test.cc    |   3 +-
 .../gcs_placement_group_manager_mock_test.cc  |   9 +-
 .../test/gcs_placement_group_manager_test.cc  |  49 +-
 .../gcs_placement_group_scheduler_test.cc     |  25 +-
 .../test/gcs_resource_manager_test.cc         |  16 +-
 .../test/gcs_resource_report_poller_test.cc   |  34 +-
 .../test/gcs_resource_scheduler_test.cc       |  13 +-
 .../gcs_server/test/gcs_server_rpc_test.cc    |  59 +-
 .../gcs_server/test/gcs_server_test_util.h    |  37 +-
 .../test/gcs_table_storage_test_base.h        |   4 +-
 .../grpc_based_resource_broadcaster_test.cc   |   3 +-
 .../test/raylet_based_actor_scheduler_test.cc |  85 ++-
 .../test/redis_gcs_table_storage_test.cc      |   4 +-
 src/ray/gcs/pb_util.h                         |  25 +-
 src/ray/gcs/pubsub/gcs_pub_sub.cc             |  76 +-
 src/ray/gcs/pubsub/gcs_pub_sub.h              |  39 +-
 src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc   |   6 +-
 src/ray/gcs/redis_async_context.cc            |  12 +-
 src/ray/gcs/redis_async_context.h             |   7 +-
 src/ray/gcs/redis_client.cc                   |  46 +-
 src/ray/gcs/redis_client.h                    |  10 +-
 src/ray/gcs/redis_context.cc                  |  66 +-
 src/ray/gcs/redis_context.h                   |  35 +-
 .../store_client/in_memory_store_client.cc    |  12 +-
 .../gcs/store_client/in_memory_store_client.h |  27 +-
 .../gcs/store_client/redis_store_client.cc    |  81 ++-
 src/ray/gcs/store_client/redis_store_client.h |  39 +-
 src/ray/gcs/store_client/store_client.h       |  21 +-
 .../test/redis_store_client_test.cc           |   7 +-
 .../test/store_client_test_base.h             |  28 +-
 src/ray/gcs/test/asio_test.cc                 |   6 +-
 src/ray/gcs/test/callback_reply_test.cc       |  12 +-
 src/ray/gcs/test/gcs_test_util.h              |  77 +-
 src/ray/internal/internal.cc                  |   3 +-
 src/ray/internal/internal.h                   |   3 +-
 src/ray/object_manager/chunk_object_reader.cc |   1 +
 .../object_manager/memory_object_reader.cc    |   7 +-
 src/ray/object_manager/memory_object_reader.h |   3 +-
 src/ray/object_manager/object_buffer_pool.cc  |  35 +-
 src/ray/object_manager/object_buffer_pool.h   |  30 +-
 src/ray/object_manager/object_directory.h     |  15 +-
 src/ray/object_manager/object_manager.cc      | 155 ++--
 src/ray/object_manager/object_manager.h       |  46 +-
 src/ray/object_manager/object_reader.h        |   6 +-
 .../ownership_based_object_directory.cc       | 124 +++-
 .../ownership_based_object_directory.h        |  15 +-
 src/ray/object_manager/plasma/client.cc       | 210 ++--
 src/ray/object_manager/plasma/client.h        |  42 +-
 src/ray/object_manager/plasma/common.h        |   6 +-
 src/ray/object_manager/plasma/compat.h        |   1 +
 src/ray/object_manager/plasma/connection.cc   |  34 +-
 src/ray/object_manager/plasma/connection.h    |   3 +-
 .../plasma/create_request_queue.cc            |  15 +-
 .../plasma/create_request_queue.h             |  16 +-
 src/ray/object_manager/plasma/dlmalloc.cc     |  15 +-
 .../object_manager/plasma/eviction_policy.cc  |   3 +-
 src/ray/object_manager/plasma/fling.cc        |   5 +-
 .../plasma/get_request_queue.cc               |  10 +-
 .../object_manager/plasma/get_request_queue.h |   6 +-
 src/ray/object_manager/plasma/malloc.cc       |   4 +-
 src/ray/object_manager/plasma/malloc.h        |   5 +-
 .../plasma/object_lifecycle_manager.cc        |   9 +-
 .../plasma/object_lifecycle_manager.h         |   6 +-
 .../object_manager/plasma/plasma_allocator.cc |  26 +-
 .../object_manager/plasma/plasma_allocator.h  |   5 +-
 src/ray/object_manager/plasma/protocol.cc     | 161 +++--
 src/ray/object_manager/plasma/protocol.h      |  91 ++-
 src/ray/object_manager/plasma/store.cc        |  55 +-
 src/ray/object_manager/plasma/store.h         |  19 +-
 src/ray/object_manager/plasma/store_runner.cc |  20 +-
 src/ray/object_manager/plasma/store_runner.h  |   6 +-
 .../plasma/test/eviction_policy_test.cc       |   7 +-
 .../plasma/test/fallback_allocator_test.cc    |   7 +-
 .../test/object_lifecycle_manager_test.cc     |  17 +-
 .../plasma/test/object_store_test.cc          |  13 +-
 .../plasma/test/stats_collector_test.cc       |  22 +-
 src/ray/object_manager/pull_manager.cc        |  95 ++-
 src/ray/object_manager/pull_manager.h         |  18 +-
 src/ray/object_manager/push_manager.cc        |   3 +-
 src/ray/object_manager/push_manager.h         |   4 +-
 .../object_manager/spilled_object_reader.cc   |  44 +-
 .../object_manager/spilled_object_reader.h    |  28 +-
 .../test/get_request_queue_test.cc            |  24 +-
 .../test/object_buffer_pool_test.cc           |   9 +-
 .../ownership_based_object_directory_test.cc  |  67 +-
 .../object_manager/test/pull_manager_test.cc  | 108 +--
 .../test/spilled_object_test.cc               | 177 +++--
 src/ray/pubsub/mock_pubsub.h                  |  32 +-
 src/ray/pubsub/publisher.cc                   |  20 +-
 src/ray/pubsub/publisher.h                    |   9 +-
 src/ray/pubsub/subscriber.cc                  |  49 +-
 src/ray/pubsub/subscriber.h                   |  21 +-
 src/ray/pubsub/test/integration_test.cc       |  70 +-
 src/ray/pubsub/test/publisher_test.cc         | 164 +++--
 src/ray/pubsub/test/subscriber_test.cc        | 227 ++--
 src/ray/raylet/agent_manager.cc               |  89 ++-
 src/ray/raylet/agent_manager.h                |   6 +-
 src/ray/raylet/local_object_manager.cc        | 105 +--
 src/ray/raylet/local_object_manager.h         |  17 +-
 src/ray/raylet/main.cc                        |  46 +-
 src/ray/raylet/node_manager.cc                | 340 +++++----
 src/ray/raylet/node_manager.h                 |  15 +-
 .../placement_group_resource_manager.cc       |   5 +-
 src/ray/raylet/raylet.cc                      |  29 +-
 src/ray/raylet/raylet.h                       |   9 +-
 .../scheduling/cluster_resource_data.cc       |   7 +-
 .../scheduling/cluster_resource_manager.cc    |   8 +-
 .../scheduling/cluster_resource_manager.h     |   3 +-
 .../cluster_resource_manager_test.cc          |  21 +-
 .../scheduling/cluster_resource_scheduler.cc  |  92 ++-
 .../scheduling/cluster_resource_scheduler.h   |  14 +-
 .../cluster_resource_scheduler_test.cc        | 303 ++--
 .../raylet/scheduling/cluster_task_manager.cc |  36 +-
 .../raylet/scheduling/cluster_task_manager.h  |   6 +-
 .../cluster_task_manager_interface.h          |   6 +-
 .../scheduling/cluster_task_manager_test.cc   | 222 +++---
 src/ray/raylet/scheduling/internal.h          |   7 +-
 .../scheduling/local_resource_manager.cc      |  33 +-
 .../scheduling/local_resource_manager.h       |  12 +-
 .../raylet/scheduling/local_task_manager.cc   |  46 +-
 .../raylet/scheduling/local_task_manager.h    |  18 +-
 .../policy/hybrid_scheduling_policy.cc        |  24 +-
 .../policy/hybrid_scheduling_policy.h         |   3 +-
 .../scheduling/policy/scheduling_options.h    |  16 +-
 .../policy/scheduling_policy_test.cc          |  49 +-
 .../scheduling/scheduler_resource_reporter.cc |  10 +-
 .../scheduling/scheduler_resource_reporter.h  |   5 +-
 src/ray/raylet/scheduling/scheduler_stats.cc  |  33 +-
 .../raylet/scheduling/scheduling_policy.cc    | 257 +++++++
 src/ray/raylet/scheduling/scheduling_policy.h | 129 ++++
 .../scheduling/scheduling_policy_test.cc      | 429 +++++++++
 .../raylet/test/local_object_manager_test.cc  | 121 ++--
 src/ray/raylet/wait_manager.cc                |   6 +-
 src/ray/raylet/wait_manager.h                 |  12 +-
 src/ray/raylet/wait_manager_test.cc           |  20 +-
 src/ray/raylet/worker.cc                      |  10 +-
 src/ray/raylet/worker.h                       |  13 +-
 src/ray/raylet/worker_pool.cc                 | 288 +++++---
 src/ray/raylet/worker_pool.h                  |  58 +-
 src/ray/raylet/worker_pool_test.cc            | 351 +++++----
 src/ray/raylet_client/raylet_client.cc        | 113 ++-
 src/ray/raylet_client/raylet_client.h         |  77 +-
 .../rpc/agent_manager/agent_manager_client.h  |  11 +-
 src/ray/rpc/client_call.h                     |  20 +-
 src/ray/rpc/gcs_server/gcs_rpc_client.h       | 244 +++--
 src/ray/rpc/gcs_server/gcs_rpc_server.h       |  53 +-
 src/ray/rpc/grpc_client.h                     |  46 +-
 src/ray/rpc/grpc_server.cc                    |   6 +-
 src/ray/rpc/grpc_server.h                     |  25 +-
 src/ray/rpc/metrics_agent_client.h            |  11 +-
 .../rpc/node_manager/node_manager_client.h    |  89 ++-
 .../rpc/node_manager/node_manager_server.h    |   3 +-
 .../object_manager/object_manager_client.h    |  15 +-
 .../object_manager/object_manager_server.h    |   6 +-
 src/ray/rpc/runtime_env/runtime_env_client.h  |  15 +-
 src/ray/rpc/server_call.h                     |  41 +-
 src/ray/rpc/test/grpc_server_client_test.cc   |  32 +-
 src/ray/rpc/worker/core_worker_client.h       | 215 ++--
 src/ray/stats/metric.cc                       |   3 +-
 src/ray/stats/metric.h                        |  52 +-
 src/ray/stats/metric_defs.cc                  | 133 +++-
 src/ray/stats/metric_defs.h                   |  46 +-
 src/ray/stats/metric_exporter.cc              |  19 +-
 src/ray/stats/metric_exporter.h               |  17 +-
 src/ray/stats/metric_exporter_client_test.cc  |   4 +-
 src/ray/stats/stats.h                         |   7 +-
 src/ray/stats/stats_test.cc                   |  30 +-
 src/ray/util/event.cc                         |  30 +-
 src/ray/util/event.h                          |  26 +-
 src/ray/util/event_test.cc                    | 116 ++-
 src/ray/util/logging.cc                       |  21 +-
 src/ray/util/logging_test.cc                  |   4 +-
 src/ray/util/memory.cc                        |  11 +-
 src/ray/util/memory.h                         |   7 +-
 src/ray/util/process.cc                       |  35 +-
 src/ray/util/process.h                        |  13 +-
 src/ray/util/sample.h                         |   4 +-
 src/ray/util/sample_test.cc                   |  12 +-
 src/ray/util/sequencer.h                      |   3 +-
 src/ray/util/signal_test.cc                   |   3 +-
 src/ray/util/throttler.h                      |   1 +
 src/ray/util/throttler_test.cc                |   4 +-
 src/ray/util/util.h                           |   3 +-
 396 files changed, 10812 insertions(+), 5275 deletions(-)
 create mode 100644 src/ray/raylet/scheduling/scheduling_policy.cc
 create mode 100644 src/ray/raylet/scheduling/scheduling_policy.h
 create mode 100644 src/ray/raylet/scheduling/scheduling_policy_test.cc

diff --git a/.clang-format b/.clang-format
index 5c0f059e1..47de2b4cb 100644
--- a/.clang-format
+++ b/.clang-format
@@ -3,3 +3,5 @@ ColumnLimit: 90
 DerivePointerAlignment: false
 IndentCaseLabels: false
 PointerAlignment: Right
+BinPackArguments: false
+BinPackParameters: false
diff --git a/ci/generate_compile_commands/extract_compile_command.cc b/ci/generate_compile_commands/extract_compile_command.cc
index e7a23911c..98f24ed30 100644
--- a/ci/generate_compile_commands/extract_compile_command.cc
+++ b/ci/generate_compile_commands/extract_compile_command.cc
@@ -37,7 +37,8 @@ namespace {
 using ::google::protobuf::io::CodedInputStream;
 using ::google::protobuf::io::FileInputStream;
 
-bool ReadExtraAction(const std::string &path, blaze::ExtraActionInfo *info,
+bool ReadExtraAction(const std::string &path,
+                     blaze::ExtraActionInfo *info,
                      blaze::CppCompileInfo *cpp_info) {
   int fd = ::open(path.c_str(), O_RDONLY, S_IREAD | S_IWRITE);
   if (fd < 0) {
@@ -97,8 +98,8 @@ int main(int argc, char **argv) {
 
   std::vector<std::string> args;
   args.push_back(cpp_info.tool());
-  args.insert(args.end(), cpp_info.compiler_option().begin(),
-              cpp_info.compiler_option().end());
+  args.insert(
+      args.end(), cpp_info.compiler_option().begin(), cpp_info.compiler_option().end());
   if (std::find(args.begin(), args.end(), "-c") == args.end()) {
     args.push_back("-c");
     args.push_back(cpp_info.source_file());
diff --git a/cpp/include/ray/api.h b/cpp/include/ray/api.h
index 78b30790d..6dfab80f2 100644
--- a/cpp/include/ray/api.h
+++ b/cpp/include/ray/api.h
@@ -82,7 +82,8 @@ std::vector<std::shared_ptr<T>> Get(const std::vector<ObjectRef<T>> &object
 /// \return Two arrays, one containing locally available objects, one containing the
 /// rest.
 template <typename T>
-WaitResult<T> Wait(const std::vector<ObjectRef<T>> &objects, int num_objects,
+WaitResult<T> Wait(const std::vector<ObjectRef<T>> &objects,
+                   int num_objects,
                    int timeout_ms);
 
 /// Create a `TaskCaller` for calling remote function.
@@ -196,7 +197,8 @@ inline std::vector> Get(const std::vector> } template -inline WaitResult Wait(const std::vector> &objects, int num_objects, +inline WaitResult Wait(const std::vector> &objects, + int num_objects, int timeout_ms) { auto object_ids = ObjectRefsToObjectIDs(objects); auto results = @@ -214,9 +216,10 @@ inline WaitResult Wait(const std::vector> &objects, int num } inline ray::internal::ActorCreator Actor(PyActorClass func) { - ray::internal::RemoteFunctionHolder remote_func_holder( - func.module_name, func.function_name, func.class_name, - ray::internal::LangType::PYTHON); + ray::internal::RemoteFunctionHolder remote_func_holder(func.module_name, + func.function_name, + func.class_name, + ray::internal::LangType::PYTHON); return {ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)}; } diff --git a/cpp/include/ray/api/actor_creator.h b/cpp/include/ray/api/actor_creator.h index a8b7389c7..de0d7154b 100644 --- a/cpp/include/ray/api/actor_creator.h +++ b/cpp/include/ray/api/actor_creator.h @@ -80,13 +80,15 @@ ActorHandle, is_python_v> ActorCreator::Remote(Args &&...a if constexpr (is_python_v) { using ArgsTuple = std::tuple; - Arguments::WrapArgs(/*cross_lang=*/true, &args_, + Arguments::WrapArgs(/*cross_lang=*/true, + &args_, std::make_index_sequence{}, std::forward(args)...); } else { StaticCheck(); using ArgsTuple = RemoveReference_t>; - Arguments::WrapArgs(/*cross_lang=*/false, &args_, + Arguments::WrapArgs(/*cross_lang=*/false, + &args_, std::make_index_sequence{}, std::forward(args)...); } diff --git a/cpp/include/ray/api/actor_handle.h b/cpp/include/ray/api/actor_handle.h index 38139737d..c1c897a92 100644 --- a/cpp/include/ray/api/actor_handle.h +++ b/cpp/include/ray/api/actor_handle.h @@ -45,8 +45,8 @@ class ActorHandle { std::is_same::value || std::is_base_of::value, "Class types must be same."); ray::internal::RemoteFunctionHolder remote_func_holder(actor_func); - return ray::internal::ActorTaskCaller(internal::GetRayRuntime().get(), id_, - std::move(remote_func_holder)); + return ray::internal::ActorTaskCaller( + internal::GetRayRuntime().get(), id_, std::move(remote_func_holder)); } template diff --git a/cpp/include/ray/api/actor_task_caller.h b/cpp/include/ray/api/actor_task_caller.h index 5fb2ff510..d14c3d26d 100644 --- a/cpp/include/ray/api/actor_task_caller.h +++ b/cpp/include/ray/api/actor_task_caller.h @@ -26,7 +26,8 @@ class ActorTaskCaller { public: ActorTaskCaller() = default; - ActorTaskCaller(RayRuntime *runtime, const std::string &id, + ActorTaskCaller(RayRuntime *runtime, + const std::string &id, RemoteFunctionHolder remote_function_holder) : runtime_(runtime), id_(id), @@ -68,13 +69,15 @@ ObjectRef> ActorTaskCaller::Remote( if constexpr (is_python_v) { using ArgsTuple = std::tuple; - Arguments::WrapArgs(/*cross_lang=*/true, &args_, + Arguments::WrapArgs(/*cross_lang=*/true, + &args_, std::make_index_sequence{}, std::forward(args)...); } else { StaticCheck(); using ArgsTuple = RemoveReference_t>>; - Arguments::WrapArgs(/*cross_lang=*/false, &args_, + Arguments::WrapArgs(/*cross_lang=*/false, + &args_, std::make_index_sequence{}, std::forward(args)...); } diff --git a/cpp/include/ray/api/arguments.h b/cpp/include/ray/api/arguments.h index cd7b7b729..555a744f2 100644 --- a/cpp/include/ray/api/arguments.h +++ b/cpp/include/ray/api/arguments.h @@ -26,7 +26,8 @@ namespace internal { class Arguments { public: template - static void WrapArgsImpl(bool cross_lang, std::vector *task_args, + static void WrapArgsImpl(bool cross_lang, + std::vector 
*task_args, InputArgTypes &&arg) { if constexpr (is_object_ref_v) { PushReferenceArg(task_args, std::forward(arg)); @@ -65,8 +66,10 @@ class Arguments { } template - static void WrapArgs(bool cross_lang, std::vector *task_args, - std::index_sequence, InputArgTypes &&...args) { + static void WrapArgs(bool cross_lang, + std::vector *task_args, + std::index_sequence, + InputArgTypes &&...args) { (void)std::initializer_list{ (WrapArgsImpl>( cross_lang, task_args, std::forward(args)), @@ -77,7 +80,8 @@ class Arguments { } private: - static void PushValueArg(std::vector *task_args, msgpack::sbuffer &&buffer, + static void PushValueArg(std::vector *task_args, + msgpack::sbuffer &&buffer, std::string_view meta_str = "") { /// Pass by value. TaskArg task_arg; diff --git a/cpp/include/ray/api/function_manager.h b/cpp/include/ray/api/function_manager.h index 9cff59b58..421a5a493 100644 --- a/cpp/include/ray/api/function_manager.h +++ b/cpp/include/ray/api/function_manager.h @@ -84,7 +84,8 @@ struct Invoker { return result; } - static inline msgpack::sbuffer ApplyMember(const Function &func, msgpack::sbuffer *ptr, + static inline msgpack::sbuffer ApplyMember(const Function &func, + msgpack::sbuffer *ptr, const ArgsBufferList &args_buffer) { using RetrunType = boost::callable_traits::return_type_t; using ArgsTuple = @@ -124,7 +125,8 @@ struct Invoker { } } - static inline bool GetArgsTuple(std::tuple<> &tup, const ArgsBufferList &args_buffer, + static inline bool GetArgsTuple(std::tuple<> &tup, + const ArgsBufferList &args_buffer, std::index_sequence<>) { return true; } @@ -155,7 +157,8 @@ struct Invoker { } template - static R CallInternal(const F &f, const std::index_sequence &, + static R CallInternal(const F &f, + const std::index_sequence &, std::tuple args) { (void)args; using ArgsTuple = boost::callable_traits::args_t; @@ -165,21 +168,23 @@ struct Invoker { template static std::enable_if_t::value, msgpack::sbuffer> CallMember( const F &f, Self *self, std::tuple args) { - CallMemberInternal(f, self, std::make_index_sequence{}, - std::move(args)); + CallMemberInternal( + f, self, std::make_index_sequence{}, std::move(args)); return PackVoid(); } template static std::enable_if_t::value, msgpack::sbuffer> CallMember( const F &f, Self *self, std::tuple args) { - auto r = CallMemberInternal(f, self, std::make_index_sequence{}, - std::move(args)); + auto r = CallMemberInternal( + f, self, std::make_index_sequence{}, std::move(args)); return PackReturnValue(r); } template - static R CallMemberInternal(const F &f, Self *self, const std::index_sequence &, + static R CallMemberInternal(const F &f, + Self *self, + const std::index_sequence &, std::tuple args) { (void)args; using ArgsTuple = boost::callable_traits::args_t; @@ -288,16 +293,20 @@ class FunctionManager { template bool RegisterNonMemberFunc(std::string const &name, Function f) { return map_invokers_ - .emplace(name, std::bind(&Invoker::Apply, std::move(f), - std::placeholders::_1)) + .emplace( + name, + std::bind(&Invoker::Apply, std::move(f), std::placeholders::_1)) .second; } template bool RegisterMemberFunc(std::string const &name, Function f) { return map_mem_func_invokers_ - .emplace(name, std::bind(&Invoker::ApplyMember, std::move(f), - std::placeholders::_1, std::placeholders::_2)) + .emplace(name, + std::bind(&Invoker::ApplyMember, + std::move(f), + std::placeholders::_1, + std::placeholders::_2)) .second; } diff --git a/cpp/include/ray/api/logging.h b/cpp/include/ray/api/logging.h index e24d8b1f5..a9a8862c0 100644 --- 
a/cpp/include/ray/api/logging.h +++ b/cpp/include/ray/api/logging.h @@ -75,7 +75,8 @@ class RayLogger { virtual std::ostream &Stream() = 0; }; -std::unique_ptr CreateRayLogger(const char *file_name, int line_number, +std::unique_ptr CreateRayLogger(const char *file_name, + int line_number, RayLoggerLevel severity); bool IsLevelEnabled(RayLoggerLevel log_level); diff --git a/cpp/include/ray/api/ray_config.h b/cpp/include/ray/api/ray_config.h index 146858b2d..c7880ab94 100644 --- a/cpp/include/ray/api/ray_config.h +++ b/cpp/include/ray/api/ray_config.h @@ -14,10 +14,12 @@ #pragma once #include + #include #include #include #include + #include "boost/optional.hpp" namespace ray { diff --git a/cpp/include/ray/api/ray_remote.h b/cpp/include/ray/api/ray_remote.h index 3aae04f2d..38080d011 100644 --- a/cpp/include/ray/api/ray_remote.h +++ b/cpp/include/ray/api/ray_remote.h @@ -60,7 +60,8 @@ inline static int RegisterRemoteFunctions(const T &t, U... u) { (void)std::initializer_list{ (FunctionManager::Instance().RegisterRemoteFunction( std::string(func_names[index].data(), func_names[index].length()), u), - index++, 0)...}; + index++, + 0)...}; return 0; } diff --git a/cpp/include/ray/api/ray_runtime.h b/cpp/include/ray/api/ray_runtime.h index cf5ab5efc..ad29e7928 100644 --- a/cpp/include/ray/api/ray_runtime.h +++ b/cpp/include/ray/api/ray_runtime.h @@ -29,7 +29,8 @@ namespace internal { struct RemoteFunctionHolder { RemoteFunctionHolder() = default; - RemoteFunctionHolder(const std::string &module_name, const std::string &function_name, + RemoteFunctionHolder(const std::string &module_name, + const std::string &function_name, const std::string &class_name = "", LangType lang_type = LangType::CPP) { this->module_name = module_name; @@ -61,7 +62,8 @@ class RayRuntime { virtual std::vector> Get( const std::vector &ids) = 0; - virtual std::vector Wait(const std::vector &ids, int num_objects, + virtual std::vector Wait(const std::vector &ids, + int num_objects, int timeout_ms) = 0; virtual std::string Call(const RemoteFunctionHolder &remote_function_holder, @@ -71,7 +73,8 @@ class RayRuntime { std::vector &args, const ActorCreationOptions &create_options) = 0; virtual std::string CallActor(const RemoteFunctionHolder &remote_function_holder, - const std::string &actor, std::vector &args, + const std::string &actor, + std::vector &args, const CallOptions &call_options) = 0; virtual void AddLocalReference(const std::string &id) = 0; virtual void RemoveLocalReference(const std::string &id) = 0; diff --git a/cpp/include/ray/api/task_caller.h b/cpp/include/ray/api/task_caller.h index cf49b0902..eca54c419 100644 --- a/cpp/include/ray/api/task_caller.h +++ b/cpp/include/ray/api/task_caller.h @@ -76,13 +76,15 @@ ObjectRef> TaskCaller::Remote( if constexpr (is_python_v) { using ArgsTuple = std::tuple; - Arguments::WrapArgs(/*cross_lang=*/true, &args_, + Arguments::WrapArgs(/*cross_lang=*/true, + &args_, std::make_index_sequence{}, std::forward(args)...); } else { StaticCheck(); using ArgsTuple = RemoveReference_t>; - Arguments::WrapArgs(/*cross_lang=*/false, &args_, + Arguments::WrapArgs(/*cross_lang=*/false, + &args_, std::make_index_sequence{}, std::forward(args)...); } diff --git a/cpp/include/ray/api/task_options.h b/cpp/include/ray/api/task_options.h index 5ecb01073..5bae8abca 100644 --- a/cpp/include/ray/api/task_options.h +++ b/cpp/include/ray/api/task_options.h @@ -64,7 +64,8 @@ struct PlacementGroupCreationOptions { class PlacementGroup { public: PlacementGroup() = default; - PlacementGroup(std::string 
id, PlacementGroupCreationOptions options, + PlacementGroup(std::string id, + PlacementGroupCreationOptions options, PlacementGroupState state = PlacementGroupState::UNRECOGNIZED) : id_(std::move(id)), options_(std::move(options)), state_(state) {} std::string GetID() const { return id_; } diff --git a/cpp/include/ray/api/wait_result.h b/cpp/include/ray/api/wait_result.h index 0dda52e4f..bbe333745 100644 --- a/cpp/include/ray/api/wait_result.h +++ b/cpp/include/ray/api/wait_result.h @@ -14,10 +14,10 @@ #pragma once -#include - #include +#include + namespace ray { /// \param T The type of object. diff --git a/cpp/src/ray/config_internal.cc b/cpp/src/ray/config_internal.cc index 4c6852da6..de0556f2f 100644 --- a/cpp/src/ray/config_internal.cc +++ b/cpp/src/ray/config_internal.cc @@ -25,11 +25,15 @@ ABSL_FLAG(std::string, ray_address, "", "The address of the Ray cluster to conne /// absl::flags does not provide a IsDefaultValue method, so use a non-empty dummy default /// value to support empty redis password. -ABSL_FLAG(std::string, ray_redis_password, "absl::flags dummy default value", +ABSL_FLAG(std::string, + ray_redis_password, + "absl::flags dummy default value", "Prevents external clients without the password from connecting to Redis " "if provided."); -ABSL_FLAG(std::string, ray_code_search_path, "", +ABSL_FLAG(std::string, + ray_code_search_path, + "", "A list of directories or files of dynamic libraries that specify the " "search path for user code. Only searching the top level under a directory. " "':' is used as the separator."); @@ -38,10 +42,14 @@ ABSL_FLAG(std::string, ray_job_id, "", "Assigned job id."); ABSL_FLAG(int32_t, ray_node_manager_port, 0, "The port to use for the node manager."); -ABSL_FLAG(std::string, ray_raylet_socket_name, "", +ABSL_FLAG(std::string, + ray_raylet_socket_name, + "", "It will specify the socket name used by the raylet if provided."); -ABSL_FLAG(std::string, ray_plasma_store_socket_name, "", +ABSL_FLAG(std::string, + ray_plasma_store_socket_name, + "", "It will specify the socket name used by the plasma store if provided."); ABSL_FLAG(std::string, ray_session_dir, "", "The path of this session."); @@ -50,12 +58,16 @@ ABSL_FLAG(std::string, ray_logs_dir, "", "Logs dir for workers."); ABSL_FLAG(std::string, ray_node_ip_address, "", "The ip address for this node."); -ABSL_FLAG(std::string, ray_head_args, "", +ABSL_FLAG(std::string, + ray_head_args, + "", "The command line args to be appended as parameters of the `ray start` " "command. It takes effect only if Ray head is started by a driver. Run `ray " "start --help` for details."); -ABSL_FLAG(int64_t, startup_token, -1, +ABSL_FLAG(int64_t, + startup_token, + -1, "The startup token assigned to this worker process by the raylet."); namespace ray { @@ -81,8 +93,8 @@ void ConfigInternal::Init(RayConfig &config, int argc, char **argv) { if (!FLAGS_ray_code_search_path.CurrentValue().empty()) { // Code search path like this "/path1/xxx.so:/path2". 
- code_search_path = absl::StrSplit(FLAGS_ray_code_search_path.CurrentValue(), ':', - absl::SkipEmpty()); + code_search_path = absl::StrSplit( + FLAGS_ray_code_search_path.CurrentValue(), ':', absl::SkipEmpty()); } if (!FLAGS_ray_address.CurrentValue().empty()) { SetBootstrapAddress(FLAGS_ray_address.CurrentValue()); @@ -150,8 +162,8 @@ void ConfigInternal::SetBootstrapAddress(std::string_view address) { auto pos = address.find(':'); RAY_CHECK(pos != std::string::npos); bootstrap_ip = address.substr(0, pos); - auto ret = std::from_chars(address.data() + pos + 1, address.data() + address.size(), - bootstrap_port); + auto ret = std::from_chars( + address.data() + pos + 1, address.data() + address.size(), bootstrap_port); RAY_CHECK(ret.ec == std::errc()); } } // namespace internal diff --git a/cpp/src/ray/runtime/abstract_ray_runtime.cc b/cpp/src/ray/runtime/abstract_ray_runtime.cc index ff9c5dbdb..c3f6d9b7c 100644 --- a/cpp/src/ray/runtime/abstract_ray_runtime.cc +++ b/cpp/src/ray/runtime/abstract_ray_runtime.cc @@ -111,7 +111,8 @@ std::vector> AbstractRayRuntime::Get( } std::vector AbstractRayRuntime::Wait(const std::vector &ids, - int num_objects, int timeout_ms) { + int num_objects, + int timeout_ms) { return object_store_->Wait(StringIDsToObjectIDs(ids), num_objects, timeout_ms); } @@ -129,7 +130,8 @@ std::vector> TransformArgs( auto meta_str = arg.meta_str; metadata = std::make_shared( reinterpret_cast(const_cast(meta_str.data())), - meta_str.size(), true); + meta_str.size(), + true); } ray_arg = absl::make_unique(std::make_shared( memory_buffer, metadata, std::vector())); @@ -141,7 +143,8 @@ std::vector> TransformArgs( auto &core_worker = CoreWorkerProcess::GetCoreWorker(); owner_address = core_worker.GetOwnerAddress(id); } - ray_arg = absl::make_unique(id, owner_address, + ray_arg = absl::make_unique(id, + owner_address, /*call_site=*/""); } ray_args.push_back(std::move(ray_arg)); @@ -181,8 +184,10 @@ std::string AbstractRayRuntime::CreateActor( } std::string AbstractRayRuntime::CallActor( - const RemoteFunctionHolder &remote_function_holder, const std::string &actor, - std::vector &args, const CallOptions &call_options) { + const RemoteFunctionHolder &remote_function_holder, + const std::string &actor, + std::vector &args, + const CallOptions &call_options) { InvocationSpec invocation_spec{}; if (remote_function_holder.lang_type == LangType::PYTHON) { const auto native_actor_handle = CoreWorkerProcess::GetCoreWorker().GetActorHandle( @@ -192,11 +197,11 @@ std::string AbstractRayRuntime::CallActor( RemoteFunctionHolder func_holder = remote_function_holder; func_holder.module_name = typed_descriptor->ModuleName(); func_holder.class_name = typed_descriptor->ClassName(); - invocation_spec = BuildInvocationSpec1(TaskType::ACTOR_TASK, func_holder, args, - ActorID::FromBinary(actor)); + invocation_spec = BuildInvocationSpec1( + TaskType::ACTOR_TASK, func_holder, args, ActorID::FromBinary(actor)); } else { - invocation_spec = BuildInvocationSpec1(TaskType::ACTOR_TASK, remote_function_holder, - args, ActorID::FromBinary(actor)); + invocation_spec = BuildInvocationSpec1( + TaskType::ACTOR_TASK, remote_function_holder, args, ActorID::FromBinary(actor)); } return task_submitter_->SubmitActorTask(invocation_spec, call_options).Binary(); @@ -308,7 +313,8 @@ PlacementGroup AbstractRayRuntime::GeneratePlacementGroup(const std::string &str options.bundles.emplace_back(bundle); } options.strategy = PlacementStrategy(pg_table_data.strategy()); - PlacementGroup group(pg_table_data.placement_group_id(), 
std::move(options), + PlacementGroup group(pg_table_data.placement_group_id(), + std::move(options), PlacementGroupState(pg_table_data.state())); return group; } diff --git a/cpp/src/ray/runtime/abstract_ray_runtime.h b/cpp/src/ray/runtime/abstract_ray_runtime.h index ef4468466..d9467c2c1 100644 --- a/cpp/src/ray/runtime/abstract_ray_runtime.h +++ b/cpp/src/ray/runtime/abstract_ray_runtime.h @@ -54,7 +54,8 @@ class AbstractRayRuntime : public RayRuntime { std::vector> Get(const std::vector &ids); - std::vector Wait(const std::vector &ids, int num_objects, + std::vector Wait(const std::vector &ids, + int num_objects, int timeout_ms); std::string Call(const RemoteFunctionHolder &remote_function_holder, diff --git a/cpp/src/ray/runtime/local_mode_ray_runtime.cc b/cpp/src/ray/runtime/local_mode_ray_runtime.cc index 6197f8ec7..47e9c9dc1 100644 --- a/cpp/src/ray/runtime/local_mode_ray_runtime.cc +++ b/cpp/src/ray/runtime/local_mode_ray_runtime.cc @@ -24,7 +24,8 @@ namespace ray { namespace internal { LocalModeRayRuntime::LocalModeRayRuntime() - : worker_(ray::core::WorkerType::DRIVER, ComputeDriverIdFromJob(JobID::Nil()), + : worker_(ray::core::WorkerType::DRIVER, + ComputeDriverIdFromJob(JobID::Nil()), JobID::Nil()) { object_store_ = std::unique_ptr(new LocalModeObjectStore(*this)); task_submitter_ = std::unique_ptr(new LocalModeTaskSubmitter(*this)); diff --git a/cpp/src/ray/runtime/logging.cc b/cpp/src/ray/runtime/logging.cc index 93f8576ce..81ea1b936 100644 --- a/cpp/src/ray/runtime/logging.cc +++ b/cpp/src/ray/runtime/logging.cc @@ -26,7 +26,8 @@ class RayLoggerImpl : public RayLogger, public ray::RayLog { std::ostream &Stream() override { return ray::RayLog::Stream(); } }; -std::unique_ptr CreateRayLogger(const char *file_name, int line_number, +std::unique_ptr CreateRayLogger(const char *file_name, + int line_number, RayLoggerLevel severity) { return std::make_unique(file_name, line_number, severity); } diff --git a/cpp/src/ray/runtime/object/local_mode_object_store.cc b/cpp/src/ray/runtime/object/local_mode_object_store.cc index d28b7b74a..fdc06bc14 100644 --- a/cpp/src/ray/runtime/object/local_mode_object_store.cc +++ b/cpp/src/ray/runtime/object/local_mode_object_store.cc @@ -58,9 +58,12 @@ std::shared_ptr LocalModeObjectStore::GetRaw(const ObjectID &o std::vector> LocalModeObjectStore::GetRaw( const std::vector &ids, int timeout_ms) { std::vector> results; - ::ray::Status status = - memory_store_->Get(ids, (int)ids.size(), timeout_ms, - local_mode_ray_tuntime_.GetWorkerContext(), false, &results); + ::ray::Status status = memory_store_->Get(ids, + (int)ids.size(), + timeout_ms, + local_mode_ray_tuntime_.GetWorkerContext(), + false, + &results); if (!status.ok()) { throw RayException("Get object error: " + status.ToString()); } @@ -78,15 +81,18 @@ std::vector> LocalModeObjectStore::GetRaw( } std::vector LocalModeObjectStore::Wait(const std::vector &ids, - int num_objects, int timeout_ms) { + int num_objects, + int timeout_ms) { absl::flat_hash_set memory_object_ids; for (const auto &object_id : ids) { memory_object_ids.insert(object_id); } absl::flat_hash_set ready; - ::ray::Status status = - memory_store_->Wait(memory_object_ids, num_objects, timeout_ms, - local_mode_ray_tuntime_.GetWorkerContext(), &ready); + ::ray::Status status = memory_store_->Wait(memory_object_ids, + num_objects, + timeout_ms, + local_mode_ray_tuntime_.GetWorkerContext(), + &ready); if (!status.ok()) { throw RayException("Wait object error: " + status.ToString()); } diff --git 
a/cpp/src/ray/runtime/object/local_mode_object_store.h b/cpp/src/ray/runtime/object/local_mode_object_store.h index 6f027f780..a2f4eb86c 100644 --- a/cpp/src/ray/runtime/object/local_mode_object_store.h +++ b/cpp/src/ray/runtime/object/local_mode_object_store.h @@ -29,7 +29,8 @@ class LocalModeObjectStore : public ObjectStore { public: LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime); - std::vector Wait(const std::vector &ids, int num_objects, + std::vector Wait(const std::vector &ids, + int num_objects, int timeout_ms); void AddLocalReference(const std::string &id); diff --git a/cpp/src/ray/runtime/object/native_object_store.cc b/cpp/src/ray/runtime/object/native_object_store.cc index 7add3b72b..022f77d51 100644 --- a/cpp/src/ray/runtime/object/native_object_store.cc +++ b/cpp/src/ray/runtime/object/native_object_store.cc @@ -34,7 +34,8 @@ void NativeObjectStore::PutRaw(std::shared_ptr data, auto buffer = std::make_shared<::ray::LocalMemoryBuffer>( reinterpret_cast(data->data()), data->size(), true); auto status = core_worker.Put( - ::ray::RayObject(buffer, nullptr, std::vector()), {}, + ::ray::RayObject(buffer, nullptr, std::vector()), + {}, object_id); if (!status.ok()) { throw RayException("Put object error"); @@ -48,7 +49,8 @@ void NativeObjectStore::PutRaw(std::shared_ptr data, auto buffer = std::make_shared<::ray::LocalMemoryBuffer>( reinterpret_cast(data->data()), data->size(), true); auto status = core_worker.Put( - ::ray::RayObject(buffer, nullptr, std::vector()), {}, + ::ray::RayObject(buffer, nullptr, std::vector()), + {}, object_id); if (!status.ok()) { throw RayException("Put object error"); @@ -113,7 +115,8 @@ std::vector> NativeObjectStore::GetRaw( } std::vector NativeObjectStore::Wait(const std::vector &ids, - int num_objects, int timeout_ms) { + int num_objects, + int timeout_ms) { std::vector results; auto &core_worker = CoreWorkerProcess::GetCoreWorker(); // TODO(SongGuyang): Support `fetch_local` option in API. diff --git a/cpp/src/ray/runtime/object/native_object_store.h b/cpp/src/ray/runtime/object/native_object_store.h index 0a1956d9f..e4cecf8fa 100644 --- a/cpp/src/ray/runtime/object/native_object_store.h +++ b/cpp/src/ray/runtime/object/native_object_store.h @@ -24,7 +24,8 @@ namespace internal { class NativeObjectStore : public ObjectStore { public: - std::vector Wait(const std::vector &ids, int num_objects, + std::vector Wait(const std::vector &ids, + int num_objects, int timeout_ms); void AddLocalReference(const std::string &id); diff --git a/cpp/src/ray/runtime/object/object_store.h b/cpp/src/ray/runtime/object/object_store.h index 227cb1e59..c56480833 100644 --- a/cpp/src/ray/runtime/object/object_store.h +++ b/cpp/src/ray/runtime/object/object_store.h @@ -18,6 +18,7 @@ #include #include + #include "ray/common/id.h" namespace ray { @@ -67,7 +68,8 @@ class ObjectStore { /// \param[in] num_objects The minimum number of objects to wait. /// \param[in] timeout_ms The maximum wait time in milliseconds. /// \return A vector that indicates each object has appeared or not. - virtual std::vector Wait(const std::vector &ids, int num_objects, + virtual std::vector Wait(const std::vector &ids, + int num_objects, int timeout_ms) = 0; /// Increase the reference count for this object ID. 
diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index d68db2bd6..297d158fb 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -49,27 +49,41 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, invocation.name.empty() ? functionDescriptor->DefaultTaskName() : invocation.name; // TODO (Alex): Properly set the depth here? - builder.SetCommonTaskSpec(task_id, task_name, rpc::Language::CPP, functionDescriptor, + builder.SetCommonTaskSpec(task_id, + task_name, + rpc::Language::CPP, + functionDescriptor, local_mode_ray_tuntime_.GetCurrentJobID(), - local_mode_ray_tuntime_.GetCurrentTaskId(), 0, - local_mode_ray_tuntime_.GetCurrentTaskId(), address, 1, - required_resources, required_placement_resources, "", + local_mode_ray_tuntime_.GetCurrentTaskId(), + 0, + local_mode_ray_tuntime_.GetCurrentTaskId(), + address, + 1, + required_resources, + required_placement_resources, + "", /*depth=*/0); if (invocation.task_type == TaskType::NORMAL_TASK) { } else if (invocation.task_type == TaskType::ACTOR_CREATION_TASK) { invocation.actor_id = local_mode_ray_tuntime_.GetNextActorID(); rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_default_scheduling_strategy(); - builder.SetActorCreationTaskSpec(invocation.actor_id, /*serialized_actor_handle=*/"", - scheduling_strategy, options.max_restarts, - /*max_task_retries=*/0, {}, options.max_concurrency); + builder.SetActorCreationTaskSpec(invocation.actor_id, + /*serialized_actor_handle=*/"", + scheduling_strategy, + options.max_restarts, + /*max_task_retries=*/0, + {}, + options.max_concurrency); } else if (invocation.task_type == TaskType::ACTOR_TASK) { const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(invocation.actor_id); const ObjectID actor_creation_dummy_object_id = ObjectID::FromIndex(actor_creation_task_id, 1); - builder.SetActorTaskSpec(invocation.actor_id, actor_creation_dummy_object_id, - ObjectID(), invocation.actor_counter); + builder.SetActorTaskSpec(invocation.actor_id, + actor_creation_dummy_object_id, + ObjectID(), + invocation.actor_counter); } else { throw RayException("unknown task type"); } @@ -92,20 +106,20 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, /// TODO(SongGuyang): Handle task dependencies. /// Execute actor task directly in the main thread because we must guarantee the actor /// task executed by calling order. 
- TaskExecutor::Invoke(task_specification, actor, runtime, actor_contexts_, - actor_contexts_mutex_); + TaskExecutor::Invoke( + task_specification, actor, runtime, actor_contexts_, actor_contexts_mutex_); } else { - boost::asio::post(*thread_pool_.get(), - std::bind( - [actor, mutex, runtime, this](TaskSpecification &ts) { - if (mutex) { - absl::MutexLock lock(mutex.get()); - } - TaskExecutor::Invoke(ts, actor, runtime, - this->actor_contexts_, - this->actor_contexts_mutex_); - }, - std::move(task_specification))); + boost::asio::post( + *thread_pool_.get(), + std::bind( + [actor, mutex, runtime, this](TaskSpecification &ts) { + if (mutex) { + absl::MutexLock lock(mutex.get()); + } + TaskExecutor::Invoke( + ts, actor, runtime, this->actor_contexts_, this->actor_contexts_mutex_); + }, + std::move(task_specification))); } return return_object_id; } diff --git a/cpp/src/ray/runtime/task/native_task_submitter.cc b/cpp/src/ray/runtime/task/native_task_submitter.cc index 4e6127444..d4457c12d 100644 --- a/cpp/src/ray/runtime/task/native_task_submitter.cc +++ b/cpp/src/ray/runtime/task/native_task_submitter.cc @@ -33,7 +33,8 @@ RayFunction BuildRayFunction(InvocationSpec &invocation) { auto function_descriptor = FunctionDescriptorBuilder::BuildPython( invocation.remote_function_holder.module_name, invocation.remote_function_holder.class_name, - invocation.remote_function_holder.function_name, ""); + invocation.remote_function_holder.function_name, + ""); return RayFunction(ray::Language::PYTHON, function_descriptor); } else { throw RayException("not supported yet"); @@ -76,8 +77,13 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation, bundle_id.second); placement_group_scheduling_strategy->set_placement_group_capture_child_tasks(false); } - return_refs = core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, - options, 1, false, scheduling_strategy, ""); + return_refs = core_worker.SubmitTask(BuildRayFunction(invocation), + invocation.args, + options, + 1, + false, + scheduling_strategy, + ""); } std::vector return_ids; for (const auto &ref : return_refs.value()) { @@ -121,8 +127,8 @@ ActorID NativeTaskSubmitter::CreateActor(InvocationSpec &invocation, /*is_asyncio=*/false, scheduling_strategy}; ActorID actor_id; - auto status = core_worker.CreateActor(BuildRayFunction(invocation), invocation.args, - actor_options, "", &actor_id); + auto status = core_worker.CreateActor( + BuildRayFunction(invocation), invocation.args, actor_options, "", &actor_id); if (!status.ok()) { throw RayException("Create actor error"); } @@ -150,8 +156,10 @@ ActorID NativeTaskSubmitter::GetActor(const std::string &actor_name) const { ray::PlacementGroup NativeTaskSubmitter::CreatePlacementGroup( const ray::PlacementGroupCreationOptions &create_options) { auto options = ray::core::PlacementGroupCreationOptions( - create_options.name, (ray::core::PlacementStrategy)create_options.strategy, - create_options.bundles, false); + create_options.name, + (ray::core::PlacementStrategy)create_options.strategy, + create_options.bundles, + false); ray::PlacementGroupID placement_group_id; auto status = CoreWorkerProcess::GetCoreWorker().CreatePlacementGroup( options, &placement_group_id); diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index 348442584..e731cc929 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -85,7 +85,8 @@ std::unique_ptr TaskExecutor::Execute(InvocationSpec &invocation) { /// 
TODO(qicosmos): Need to add more details of the error messages, such as object id, /// task id etc. std::pair> GetExecuteResult( - const std::string &func_name, const ArgsBufferList &args_buffer, + const std::string &func_name, + const ArgsBufferList &args_buffer, msgpack::sbuffer *actor_ptr) { try { EntryFuntion entry_function; @@ -122,11 +123,14 @@ std::pair> GetExecuteResult( } Status TaskExecutor::ExecuteTask( - ray::TaskType task_type, const std::string task_name, const RayFunction &ray_function, + ray::TaskType task_type, + const std::string task_name, + const RayFunction &ray_function, const std::unordered_map &required_resources, const std::vector> &args_buffer, const std::vector &arg_refs, - const std::vector &return_ids, const std::string &debugger_breakpoint, + const std::vector &return_ids, + const std::string &debugger_breakpoint, std::vector> *results, std::shared_ptr &creation_task_exception_pb_bytes, bool *is_application_level_error, @@ -210,8 +214,12 @@ Status TaskExecutor::ExecuteTask( size_t total = cross_lang ? (XLANG_HEADER_LEN + data_size) : data_size; RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().AllocateReturnObject( - result_id, total, meta_buffer, std::vector(), - &task_output_inlined_bytes, result_ptr)); + result_id, + total, + meta_buffer, + std::vector(), + &task_output_inlined_bytes, + result_ptr)); auto result = *result_ptr; if (result != nullptr) { @@ -243,7 +251,8 @@ Status TaskExecutor::ExecuteTask( } void TaskExecutor::Invoke( - const TaskSpecification &task_spec, std::shared_ptr actor, + const TaskSpecification &task_spec, + std::shared_ptr actor, AbstractRayRuntime *runtime, std::unordered_map> &actor_contexts, absl::Mutex &actor_contexts_mutex) { @@ -267,8 +276,8 @@ void TaskExecutor::Invoke( std::shared_ptr data; try { if (actor) { - auto result = TaskExecutionHandler(typed_descriptor->FunctionName(), args_buffer, - actor.get()); + auto result = TaskExecutionHandler( + typed_descriptor->FunctionName(), args_buffer, actor.get()); data = std::make_shared(std::move(result)); runtime->Put(std::move(data), task_spec.ReturnId(0)); } else { diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index c7b9789e6..ae414f607 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -71,18 +71,21 @@ class TaskExecutor { std::unique_ptr Execute(InvocationSpec &invocation); static void Invoke( - const TaskSpecification &task_spec, std::shared_ptr actor, + const TaskSpecification &task_spec, + std::shared_ptr actor, AbstractRayRuntime *runtime, std::unordered_map> &actor_contexts, absl::Mutex &actor_contexts_mutex); static Status ExecuteTask( - ray::TaskType task_type, const std::string task_name, + ray::TaskType task_type, + const std::string task_name, const RayFunction &ray_function, const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_refs, - const std::vector &return_ids, const std::string &debugger_breakpoint, + const std::vector &return_ids, + const std::string &debugger_breakpoint, std::vector> *results, std::shared_ptr &creation_task_exception_pb_bytes, bool *is_application_level_error, diff --git a/cpp/src/ray/test/api_test.cc b/cpp/src/ray/test/api_test.cc index 03e4deaef..c4d401689 100644 --- a/cpp/src/ray/test/api_test.cc +++ b/cpp/src/ray/test/api_test.cc @@ -98,8 +98,14 @@ class Counter { } }; -RAY_REMOTE(Counter::FactoryCreate, &Counter::Plus1, &Counter::Plus, &Counter::Triple, - &Counter::Add, &Counter::GetVal, 
&Counter::GetIntVal, &Counter::GetList); +RAY_REMOTE(Counter::FactoryCreate, + &Counter::Plus1, + &Counter::Plus, + &Counter::Triple, + &Counter::Add, + &Counter::GetVal, + &Counter::GetIntVal, + &Counter::GetList); TEST(RayApiTest, LogTest) { auto log_path = boost::filesystem::current_path().string() + "/tmp/"; @@ -325,8 +331,8 @@ TEST(RayApiTest, CompareWithFuture) { TEST(RayApiTest, CreateAndRemovePlacementGroup) { std::vector> bundles{{{"CPU", 1}}}; - ray::PlacementGroupCreationOptions options1{"first_placement_group", bundles, - ray::PlacementStrategy::PACK}; + ray::PlacementGroupCreationOptions options1{ + "first_placement_group", bundles, ray::PlacementStrategy::PACK}; auto first_placement_group = ray::CreatePlacementGroup(options1); EXPECT_TRUE(first_placement_group.Wait(10)); diff --git a/cpp/src/ray/test/cluster/cluster_mode_test.cc b/cpp/src/ray/test/cluster/cluster_mode_test.cc index 881a5560a..6e84d3ee6 100644 --- a/cpp/src/ray/test/cluster/cluster_mode_test.cc +++ b/cpp/src/ray/test/cluster/cluster_mode_test.cc @@ -45,8 +45,8 @@ struct Person { TEST(RayClusterModeTest, FullTest) { ray::RayConfig config; - config.head_args = {"--num-cpus", "2", "--resources", - "{\"resource1\":1,\"resource2\":2}"}; + config.head_args = { + "--num-cpus", "2", "--resources", "{\"resource1\":1,\"resource2\":2}"}; if (absl::GetFlag(FLAGS_external_cluster)) { auto port = absl::GetFlag(FLAGS_redis_port); std::string password = absl::GetFlag(FLAGS_redis_password); @@ -401,8 +401,8 @@ TEST(RayClusterModeTest, CreateAndRemovePlacementGroup) { TEST(RayClusterModeTest, CreatePlacementGroupExceedsClusterResource) { std::vector> bundles{{{"CPU", 10000}}}; - ray::PlacementGroupCreationOptions options{"first_placement_group", bundles, - ray::PlacementStrategy::PACK}; + ray::PlacementGroupCreationOptions options{ + "first_placement_group", bundles, ray::PlacementStrategy::PACK}; auto first_placement_group = ray::CreatePlacementGroup(options); EXPECT_FALSE(first_placement_group.Wait(3)); ray::RemovePlacementGroup(first_placement_group.GetID()); diff --git a/cpp/src/ray/test/cluster/counter.cc b/cpp/src/ray/test/cluster/counter.cc index 479775cd0..a2695d7bd 100644 --- a/cpp/src/ray/test/cluster/counter.cc +++ b/cpp/src/ray/test/cluster/counter.cc @@ -83,11 +83,18 @@ bool Counter::CheckRestartInActorCreationTask() { return is_restared; } bool Counter::CheckRestartInActorTask() { return ray::WasCurrentActorRestarted(); } -RAY_REMOTE(RAY_FUNC(Counter::FactoryCreate), Counter::FactoryCreateException, +RAY_REMOTE(RAY_FUNC(Counter::FactoryCreate), + Counter::FactoryCreateException, RAY_FUNC(Counter::FactoryCreate, int), - RAY_FUNC(Counter::FactoryCreate, int, int), &Counter::Plus1, &Counter::Add, - &Counter::Exit, &Counter::GetPid, &Counter::ExceptionFunc, - &Counter::CheckRestartInActorCreationTask, &Counter::CheckRestartInActorTask, - &Counter::GetVal, &Counter::GetIntVal); + RAY_FUNC(Counter::FactoryCreate, int, int), + &Counter::Plus1, + &Counter::Add, + &Counter::Exit, + &Counter::GetPid, + &Counter::ExceptionFunc, + &Counter::CheckRestartInActorCreationTask, + &Counter::CheckRestartInActorTask, + &Counter::GetVal, + &Counter::GetIntVal); RAY_REMOTE(ActorConcurrentCall::FactoryCreate, &ActorConcurrentCall::CountDown); diff --git a/cpp/src/ray/test/cluster/plus.cc b/cpp/src/ray/test/cluster/plus.cc index 345a054e3..8f537ccd7 100644 --- a/cpp/src/ray/test/cluster/plus.cc +++ b/cpp/src/ray/test/cluster/plus.cc @@ -38,5 +38,15 @@ Student GetStudent(Student student) { return student; } std::map GetStudents(std::map 
students) { return students; } -RAY_REMOTE(Return1, Plus1, Plus, ThrowTask, ReturnLargeArray, Echo, GetMap, GetArray, - GetList, GetTuple, GetStudent, GetStudents); +RAY_REMOTE(Return1, + Plus1, + Plus, + ThrowTask, + ReturnLargeArray, + Echo, + GetMap, + GetArray, + GetList, + GetTuple, + GetStudent, + GetStudents); diff --git a/cpp/src/ray/test/examples/simple_kv_store.cc b/cpp/src/ray/test/examples/simple_kv_store.cc index cd32a4689..93d79593b 100644 --- a/cpp/src/ray/test/examples/simple_kv_store.cc +++ b/cpp/src/ray/test/examples/simple_kv_store.cc @@ -151,8 +151,8 @@ void StartServer() { // different nodes if possible. std::vector> bundles{RESOUECES, RESOUECES}; - ray::PlacementGroupCreationOptions options{"kv_server_pg", bundles, - ray::PlacementStrategy::SPREAD}; + ray::PlacementGroupCreationOptions options{ + "kv_server_pg", bundles, ray::PlacementStrategy::SPREAD}; auto placement_group = ray::CreatePlacementGroup(options); // Wait until the placement group is created. assert(placement_group.Wait(10)); diff --git a/cpp/src/ray/util/process_helper.cc b/cpp/src/ray/util/process_helper.cc index 3a58ea3b8..66b7e26f0 100644 --- a/cpp/src/ray/util/process_helper.cc +++ b/cpp/src/ray/util/process_helper.cc @@ -27,11 +27,18 @@ namespace internal { using ray::core::CoreWorkerProcess; using ray::core::WorkerType; -void ProcessHelper::StartRayNode(const int redis_port, const std::string redis_password, +void ProcessHelper::StartRayNode(const int redis_port, + const std::string redis_password, const std::vector &head_args) { - std::vector cmdargs( - {"ray", "start", "--head", "--port", std::to_string(redis_port), "--redis-password", - redis_password, "--node-ip-address", GetNodeIpAddress()}); + std::vector cmdargs({"ray", + "start", + "--head", + "--port", + std::to_string(redis_port), + "--redis-password", + redis_password, + "--node-ip-address", + GetNodeIpAddress()}); if (!head_args.empty()) { cmdargs.insert(cmdargs.end(), head_args.begin(), head_args.end()); } @@ -63,8 +70,8 @@ std::unique_ptr ProcessHelper::CreateGlobalStateA std::vector address; boost::split(address, redis_address, boost::is_any_of(":")); RAY_CHECK(address.size() == 2); - ray::gcs::GcsClientOptions client_options(address[0], std::stoi(address[1]), - redis_password); + ray::gcs::GcsClientOptions client_options( + address[0], std::stoi(address[1]), redis_password); auto global_state_accessor = std::make_unique(client_options); @@ -79,7 +86,8 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback) if (ConfigInternal::Instance().worker_type == WorkerType::DRIVER && bootstrap_ip.empty()) { bootstrap_ip = "127.0.0.1"; - StartRayNode(bootstrap_port, ConfigInternal::Instance().redis_password, + StartRayNode(bootstrap_port, + ConfigInternal::Instance().redis_password, ConfigInternal::Instance().head_args); } if (bootstrap_ip == "127.0.0.1") { @@ -129,7 +137,8 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback) gcs::GcsClientOptions gcs_options = ::RayConfig::instance().bootstrap_with_gcs() ? 
gcs::GcsClientOptions(bootstrap_address) - : gcs::GcsClientOptions(bootstrap_ip, ConfigInternal::Instance().bootstrap_port, + : gcs::GcsClientOptions(bootstrap_ip, + ConfigInternal::Instance().bootstrap_port, ConfigInternal::Instance().redis_password); CoreWorkerOptions options; diff --git a/cpp/src/ray/util/process_helper.h b/cpp/src/ray/util/process_helper.h index 6e17a2e05..d83f7fe8e 100644 --- a/cpp/src/ray/util/process_helper.h +++ b/cpp/src/ray/util/process_helper.h @@ -29,7 +29,8 @@ class ProcessHelper { public: void RayStart(CoreWorkerOptions::TaskExecutionCallback callback); void RayStop(); - void StartRayNode(const int redis_port, const std::string redis_password, + void StartRayNode(const int redis_port, + const std::string redis_password, const std::vector &head_args = {}); void StopRayNode(); diff --git a/cpp/src/ray/util/util.cc b/cpp/src/ray/util/util.cc index 7cc00d416..ad3a00cd7 100644 --- a/cpp/src/ray/util/util.cc +++ b/cpp/src/ray/util/util.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "util.h" + #include #include + #include "ray/util/logging.h" namespace ray { @@ -27,8 +29,8 @@ std::string GetNodeIpAddress(const std::string &address) { try { boost::asio::io_service netService; boost::asio::ip::udp::resolver resolver(netService); - boost::asio::ip::udp::resolver::query query(boost::asio::ip::udp::v4(), parts[0], - parts[1]); + boost::asio::ip::udp::resolver::query query( + boost::asio::ip::udp::v4(), parts[0], parts[1]); boost::asio::ip::udp::resolver::iterator endpoints = resolver.resolve(query); boost::asio::ip::udp::endpoint ep = *endpoints; boost::asio::ip::udp::socket socket(netService); diff --git a/cpp/src/ray/worker/default_worker.cc b/cpp/src/ray/worker/default_worker.cc index ebb1a4fcf..11eb69c71 100644 --- a/cpp/src/ray/worker/default_worker.cc +++ b/cpp/src/ray/worker/default_worker.cc @@ -14,10 +14,10 @@ #include #include -#include "ray/core_worker/common.h" -#include "ray/core_worker/core_worker.h" #include "../config_internal.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/core_worker.h" int main(int argc, char **argv) { RAY_LOG(INFO) << "CPP default worker started."; diff --git a/src/mock/ray/core_worker/actor_creator.h b/src/mock/ray/core_worker/actor_creator.h index 305fb8146..23e3e3e7e 100644 --- a/src/mock/ray/core_worker/actor_creator.h +++ b/src/mock/ray/core_worker/actor_creator.h @@ -19,17 +19,23 @@ namespace core { class MockActorCreatorInterface : public ActorCreatorInterface { public: - MOCK_METHOD(Status, RegisterActor, (const TaskSpecification &task_spec), + MOCK_METHOD(Status, + RegisterActor, + (const TaskSpecification &task_spec), (const, override)); - MOCK_METHOD(Status, AsyncRegisterActor, + MOCK_METHOD(Status, + AsyncRegisterActor, (const TaskSpecification &task_spec, gcs::StatusCallback callback), (override)); - MOCK_METHOD(Status, AsyncCreateActor, + MOCK_METHOD(Status, + AsyncCreateActor, (const TaskSpecification &task_spec, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, AsyncWaitForActorRegisterFinish, - (const ActorID &actor_id, gcs::StatusCallback callback), (override)); + MOCK_METHOD(void, + AsyncWaitForActorRegisterFinish, + (const ActorID &actor_id, gcs::StatusCallback callback), + (override)); MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override)); }; @@ -41,13 +47,17 @@ namespace core { class MockDefaultActorCreator : public DefaultActorCreator { public: - MOCK_METHOD(Status, RegisterActor, (const TaskSpecification &task_spec), + 
MOCK_METHOD(Status, + RegisterActor, + (const TaskSpecification &task_spec), (const, override)); - MOCK_METHOD(Status, AsyncRegisterActor, + MOCK_METHOD(Status, + AsyncRegisterActor, (const TaskSpecification &task_spec, gcs::StatusCallback callback), (override)); MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override)); - MOCK_METHOD(Status, AsyncCreateActor, + MOCK_METHOD(Status, + AsyncCreateActor, (const TaskSpecification &task_spec, const rpc::ClientCallback &callback), (override)); diff --git a/src/mock/ray/core_worker/core_worker.h b/src/mock/ray/core_worker/core_worker.h index a078a1a32..49736faf2 100644 --- a/src/mock/ray/core_worker/core_worker.h +++ b/src/mock/ray/core_worker/core_worker.h @@ -39,105 +39,134 @@ namespace core { class MockCoreWorker : public CoreWorker { public: - MOCK_METHOD(void, HandlePushTask, - (const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + MOCK_METHOD(void, + HandlePushTask, + (const rpc::PushTaskRequest &request, + rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleDirectActorCallArgWaitComplete, + MOCK_METHOD(void, + HandleDirectActorCallArgWaitComplete, (const rpc::DirectActorCallArgWaitCompleteRequest &request, rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetObjectStatus, + MOCK_METHOD(void, + HandleGetObjectStatus, (const rpc::GetObjectStatusRequest &request, rpc::GetObjectStatusReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleWaitForActorOutOfScope, + MOCK_METHOD(void, + HandleWaitForActorOutOfScope, (const rpc::WaitForActorOutOfScopeRequest &request, rpc::WaitForActorOutOfScopeReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandlePubsubLongPolling, + MOCK_METHOD(void, + HandlePubsubLongPolling, (const rpc::PubsubLongPollingRequest &request, rpc::PubsubLongPollingReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandlePubsubCommandBatch, + MOCK_METHOD(void, + HandlePubsubCommandBatch, (const rpc::PubsubCommandBatchRequest &request, rpc::PubsubCommandBatchReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleAddObjectLocationOwner, + MOCK_METHOD(void, + HandleAddObjectLocationOwner, (const rpc::AddObjectLocationOwnerRequest &request, rpc::AddObjectLocationOwnerReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRemoveObjectLocationOwner, + MOCK_METHOD(void, + HandleRemoveObjectLocationOwner, (const rpc::RemoveObjectLocationOwnerRequest &request, rpc::RemoveObjectLocationOwnerReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetObjectLocationsOwner, + MOCK_METHOD(void, + HandleGetObjectLocationsOwner, (const rpc::GetObjectLocationsOwnerRequest &request, rpc::GetObjectLocationsOwnerReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleKillActor, - (const rpc::KillActorRequest &request, rpc::KillActorReply *reply, + MOCK_METHOD(void, + HandleKillActor, + (const rpc::KillActorRequest &request, + rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleCancelTask, - (const rpc::CancelTaskRequest &request, rpc::CancelTaskReply *reply, + MOCK_METHOD(void, + HandleCancelTask, + (const 
rpc::CancelTaskRequest &request, + rpc::CancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRemoteCancelTask, + MOCK_METHOD(void, + HandleRemoteCancelTask, (const rpc::RemoteCancelTaskRequest &request, rpc::RemoteCancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandlePlasmaObjectReady, + MOCK_METHOD(void, + HandlePlasmaObjectReady, (const rpc::PlasmaObjectReadyRequest &request, rpc::PlasmaObjectReadyReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetCoreWorkerStats, + MOCK_METHOD(void, + HandleGetCoreWorkerStats, (const rpc::GetCoreWorkerStatsRequest &request, rpc::GetCoreWorkerStatsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleLocalGC, - (const rpc::LocalGCRequest &request, rpc::LocalGCReply *reply, + MOCK_METHOD(void, + HandleLocalGC, + (const rpc::LocalGCRequest &request, + rpc::LocalGCReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRunOnUtilWorker, + MOCK_METHOD(void, + HandleRunOnUtilWorker, (const rpc::RunOnUtilWorkerRequest &request, rpc::RunOnUtilWorkerReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleSpillObjects, - (const rpc::SpillObjectsRequest &request, rpc::SpillObjectsReply *reply, + MOCK_METHOD(void, + HandleSpillObjects, + (const rpc::SpillObjectsRequest &request, + rpc::SpillObjectsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleAddSpilledUrl, - (const rpc::AddSpilledUrlRequest &request, rpc::AddSpilledUrlReply *reply, + MOCK_METHOD(void, + HandleAddSpilledUrl, + (const rpc::AddSpilledUrlRequest &request, + rpc::AddSpilledUrlReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRestoreSpilledObjects, + MOCK_METHOD(void, + HandleRestoreSpilledObjects, (const rpc::RestoreSpilledObjectsRequest &request, rpc::RestoreSpilledObjectsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleDeleteSpilledObjects, + MOCK_METHOD(void, + HandleDeleteSpilledObjects, (const rpc::DeleteSpilledObjectsRequest &request, rpc::DeleteSpilledObjectsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleExit, - (const rpc::ExitRequest &request, rpc::ExitReply *reply, + MOCK_METHOD(void, + HandleExit, + (const rpc::ExitRequest &request, + rpc::ExitReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleAssignObjectOwner, + MOCK_METHOD(void, + HandleAssignObjectOwner, (const rpc::AssignObjectOwnerRequest &request, rpc::AssignObjectOwnerReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/core_worker/lease_policy.h b/src/mock/ray/core_worker/lease_policy.h index 2f0b0d672..6a0282797 100644 --- a/src/mock/ray/core_worker/lease_policy.h +++ b/src/mock/ray/core_worker/lease_policy.h @@ -27,7 +27,9 @@ namespace core { class MockLocalityDataProviderInterface : public LocalityDataProviderInterface { public: - MOCK_METHOD(absl::optional, GetLocalityData, (const ObjectID &object_id), + MOCK_METHOD(absl::optional, + GetLocalityData, + (const ObjectID &object_id), (override)); }; @@ -39,8 +41,10 @@ namespace core { class MockLeasePolicyInterface : public LeasePolicyInterface { public: - MOCK_METHOD((std::pair), GetBestNodeForTask, - (const TaskSpecification 
&spec), (override)); + MOCK_METHOD((std::pair), + GetBestNodeForTask, + (const TaskSpecification &spec), + (override)); }; } // namespace core diff --git a/src/mock/ray/core_worker/task_manager.h b/src/mock/ray/core_worker/task_manager.h index f5b64ad77..426800dbf 100644 --- a/src/mock/ray/core_worker/task_manager.h +++ b/src/mock/ray/core_worker/task_manager.h @@ -19,28 +19,43 @@ namespace core { class MockTaskFinisherInterface : public TaskFinisherInterface { public: - MOCK_METHOD(void, CompletePendingTask, - (const TaskID &task_id, const rpc::PushTaskReply &reply, + MOCK_METHOD(void, + CompletePendingTask, + (const TaskID &task_id, + const rpc::PushTaskReply &reply, const rpc::Address &actor_addr), (override)); - MOCK_METHOD(void, FailPendingTask, - (const TaskID &task_id, rpc::ErrorType error_type, const Status *status, - const rpc::RayErrorInfo *ray_error_info, bool mark_task_object_failed), + MOCK_METHOD(void, + FailPendingTask, + (const TaskID &task_id, + rpc::ErrorType error_type, + const Status *status, + const rpc::RayErrorInfo *ray_error_info, + bool mark_task_object_failed), (override)); - MOCK_METHOD(bool, FailOrRetryPendingTask, - (const TaskID &task_id, rpc::ErrorType error_type, const Status *status, - const rpc::RayErrorInfo *ray_error_info, bool mark_task_object_failed), + MOCK_METHOD(bool, + FailOrRetryPendingTask, + (const TaskID &task_id, + rpc::ErrorType error_type, + const Status *status, + const rpc::RayErrorInfo *ray_error_info, + bool mark_task_object_failed), (override)); - MOCK_METHOD(void, OnTaskDependenciesInlined, + MOCK_METHOD(void, + OnTaskDependenciesInlined, (const std::vector &inlined_dependency_ids, const std::vector &contained_ids), (override)); MOCK_METHOD(bool, MarkTaskCanceled, (const TaskID &task_id), (override)); - MOCK_METHOD(void, MarkTaskReturnObjectsFailed, - (const TaskSpecification &spec, rpc::ErrorType error_type, + MOCK_METHOD(void, + MarkTaskReturnObjectsFailed, + (const TaskSpecification &spec, + rpc::ErrorType error_type, const rpc::RayErrorInfo *ray_error_info), (override)); - MOCK_METHOD(absl::optional, GetTaskSpec, (const TaskID &task_id), + MOCK_METHOD(absl::optional, + GetTaskSpec, + (const TaskID &task_id), (const, override)); MOCK_METHOD(bool, RetryTaskIfPossible, (const TaskID &task_id), (override)); MOCK_METHOD(void, MarkDependenciesResolved, (const TaskID &task_id), (override)); @@ -54,8 +69,10 @@ namespace core { class MockTaskResubmissionInterface : public TaskResubmissionInterface { public: - MOCK_METHOD(bool, ResubmitTask, - (const TaskID &task_id, std::vector *task_deps), (override)); + MOCK_METHOD(bool, + ResubmitTask, + (const TaskID &task_id, std::vector *task_deps), + (override)); }; } // namespace core diff --git a/src/mock/ray/core_worker/transport/direct_actor_transport.h b/src/mock/ray/core_worker/transport/direct_actor_transport.h index 5225851a6..1904a3517 100644 --- a/src/mock/ray/core_worker/transport/direct_actor_transport.h +++ b/src/mock/ray/core_worker/transport/direct_actor_transport.h @@ -17,24 +17,34 @@ namespace core { class MockCoreWorkerDirectActorTaskSubmitterInterface : public CoreWorkerDirectActorTaskSubmitterInterface { public: - MOCK_METHOD(void, AddActorQueueIfNotExists, - (const ActorID &actor_id, int32_t max_pending_calls), (override)); - MOCK_METHOD(void, ConnectActor, - (const ActorID &actor_id, const rpc::Address &address, + MOCK_METHOD(void, + AddActorQueueIfNotExists, + (const ActorID &actor_id, int32_t max_pending_calls), + (override)); + MOCK_METHOD(void, + ConnectActor, + (const 
ActorID &actor_id, + const rpc::Address &address, int64_t num_restarts), (override)); - MOCK_METHOD(void, DisconnectActor, - (const ActorID &actor_id, int64_t num_restarts, bool dead, + MOCK_METHOD(void, + DisconnectActor, + (const ActorID &actor_id, + int64_t num_restarts, + bool dead, const rpc::RayException *creation_task_exception), (override)); - MOCK_METHOD(void, KillActor, - (const ActorID &actor_id, bool force_kill, bool no_restart), (override)); + MOCK_METHOD(void, + KillActor, + (const ActorID &actor_id, bool force_kill, bool no_restart), + (override)); MOCK_METHOD(void, CheckTimeoutTasks, (), (override)); }; class MockDependencyWaiter : public DependencyWaiter { public: - MOCK_METHOD(void, Wait, + MOCK_METHOD(void, + Wait, (const std::vector &dependencies, std::function on_dependencies_available), (override)); @@ -42,13 +52,16 @@ class MockDependencyWaiter : public DependencyWaiter { class MockSchedulingQueue : public SchedulingQueue { public: - MOCK_METHOD(void, Add, - (int64_t seq_no, int64_t client_processed_up_to, + MOCK_METHOD(void, + Add, + (int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, - const ray::FunctionDescriptor &function_descriptor, TaskID task_id, + const ray::FunctionDescriptor &function_descriptor, + TaskID task_id, const std::vector &dependencies), (override)); MOCK_METHOD(void, ScheduleRequests, (), (override)); diff --git a/src/mock/ray/gcs/gcs_client/accessor.h b/src/mock/ray/gcs/gcs_client/accessor.h index 29c197140..d0586b213 100644 --- a/src/mock/ray/gcs/gcs_client/accessor.h +++ b/src/mock/ray/gcs/gcs_client/accessor.h @@ -20,37 +20,53 @@ namespace gcs { class MockActorInfoAccessor : public ActorInfoAccessor { public: - MOCK_METHOD(Status, AsyncGet, + MOCK_METHOD(Status, + AsyncGet, (const ActorID &actor_id, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetAll, - (const MultiItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetByName, - (const std::string &name, const std::string &ray_namespace, + MOCK_METHOD(Status, + AsyncGetAll, + (const MultiItemCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncGetByName, + (const std::string &name, + const std::string &ray_namespace, const OptionalItemCallback &callback, int64_t timeout_ms), (override)); - MOCK_METHOD(Status, AsyncListNamedActors, - (bool all_namespaces, const std::string &ray_namespace, + MOCK_METHOD(Status, + AsyncListNamedActors, + (bool all_namespaces, + const std::string &ray_namespace, const OptionalItemCallback> &callback, int64_t timeout_ms), (override)); - MOCK_METHOD(Status, AsyncRegisterActor, - (const TaskSpecification &task_spec, const StatusCallback &callback, + MOCK_METHOD(Status, + AsyncRegisterActor, + (const TaskSpecification &task_spec, + const StatusCallback &callback, int64_t timeout_ms), (override)); - MOCK_METHOD(Status, SyncRegisterActor, (const TaskSpecification &task_spec), + MOCK_METHOD(Status, + SyncRegisterActor, + (const TaskSpecification &task_spec), (override)); - MOCK_METHOD(Status, AsyncKillActor, - (const ActorID &actor_id, bool force_kill, bool no_restart, + MOCK_METHOD(Status, + AsyncKillActor, + (const ActorID &actor_id, + bool force_kill, + bool no_restart, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncCreateActor, + MOCK_METHOD(Status, + AsyncCreateActor, (const TaskSpecification &task_spec, const rpc::ClientCallback 
&callback), (override)); - MOCK_METHOD(Status, AsyncSubscribe, + MOCK_METHOD(Status, + AsyncSubscribe, (const ActorID &actor_id, (const SubscribeCallback &subscribe), const StatusCallback &done), @@ -68,20 +84,28 @@ namespace gcs { class MockJobInfoAccessor : public JobInfoAccessor { public: - MOCK_METHOD(Status, AsyncAdd, + MOCK_METHOD(Status, + AsyncAdd, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncMarkFinished, - (const JobID &job_id, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncSubscribeAll, + MOCK_METHOD(Status, + AsyncMarkFinished, + (const JobID &job_id, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncSubscribeAll, ((const SubscribeCallback &subscribe), const StatusCallback &done), (override)); - MOCK_METHOD(Status, AsyncGetAll, (const MultiItemCallback &callback), + MOCK_METHOD(Status, + AsyncGetAll, + (const MultiItemCallback &callback), (override)); MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override)); - MOCK_METHOD(Status, AsyncGetNextJobID, (const ItemCallback &callback), + MOCK_METHOD(Status, + AsyncGetNextJobID, + (const ItemCallback &callback), (override)); }; @@ -93,35 +117,49 @@ namespace gcs { class MockNodeInfoAccessor : public NodeInfoAccessor { public: - MOCK_METHOD(Status, RegisterSelf, + MOCK_METHOD(Status, + RegisterSelf, (const rpc::GcsNodeInfo &local_node_info, const StatusCallback &callback), (override)); MOCK_METHOD(Status, DrainSelf, (), (override)); MOCK_METHOD(const NodeID &, GetSelfId, (), (const, override)); MOCK_METHOD(const rpc::GcsNodeInfo &, GetSelfInfo, (), (const, override)); - MOCK_METHOD(Status, AsyncRegister, + MOCK_METHOD(Status, + AsyncRegister, (const rpc::GcsNodeInfo &node_info, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDrainNode, - (const NodeID &node_id, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetAll, (const MultiItemCallback &callback), + MOCK_METHOD(Status, + AsyncDrainNode, + (const NodeID &node_id, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncSubscribeToNodeChange, + MOCK_METHOD(Status, + AsyncGetAll, + (const MultiItemCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncSubscribeToNodeChange, ((const SubscribeCallback &subscribe), const StatusCallback &done), (override)); - MOCK_METHOD(const rpc::GcsNodeInfo *, Get, - (const NodeID &node_id, bool filter_dead_nodes), (const, override)); - MOCK_METHOD((const absl::flat_hash_map &), GetAll, (), + MOCK_METHOD(const rpc::GcsNodeInfo *, + Get, + (const NodeID &node_id, bool filter_dead_nodes), + (const, override)); + MOCK_METHOD((const absl::flat_hash_map &), + GetAll, + (), (const, override)); MOCK_METHOD(bool, IsRemoved, (const NodeID &node_id), (const, override)); - MOCK_METHOD(Status, AsyncReportHeartbeat, + MOCK_METHOD(Status, + AsyncReportHeartbeat, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override)); - MOCK_METHOD(Status, AsyncGetInternalConfig, - (const OptionalItemCallback &callback), (override)); + MOCK_METHOD(Status, + AsyncGetInternalConfig, + (const OptionalItemCallback &callback), + (override)); }; } // namespace gcs @@ -132,24 +170,32 @@ namespace gcs { class MockNodeResourceInfoAccessor : public NodeResourceInfoAccessor { public: - MOCK_METHOD(Status, AsyncGetResources, + MOCK_METHOD(Status, + AsyncGetResources, (const 
NodeID &node_id, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetAllAvailableResources, - (const MultiItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncSubscribeToResources, + MOCK_METHOD(Status, + AsyncGetAllAvailableResources, + (const MultiItemCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncSubscribeToResources, (const ItemCallback &subscribe, const StatusCallback &done), (override)); MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override)); - MOCK_METHOD(Status, AsyncReportResourceUsage, + MOCK_METHOD(Status, + AsyncReportResourceUsage, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncReReportResourceUsage, (), (override)); - MOCK_METHOD(Status, AsyncGetAllResourceUsage, - (const ItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncSubscribeBatchedResourceUsage, + MOCK_METHOD(Status, + AsyncGetAllResourceUsage, + (const ItemCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncSubscribeBatchedResourceUsage, (const ItemCallback &subscribe, const StatusCallback &done), (override)); @@ -163,7 +209,8 @@ namespace gcs { class MockErrorInfoAccessor : public ErrorInfoAccessor { public: - MOCK_METHOD(Status, AsyncReportJobError, + MOCK_METHOD(Status, + AsyncReportJobError, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); @@ -177,12 +224,15 @@ namespace gcs { class MockStatsInfoAccessor : public StatsInfoAccessor { public: - MOCK_METHOD(Status, AsyncAddProfileData, + MOCK_METHOD(Status, + AsyncAddProfileData, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetAll, - (const MultiItemCallback &callback), (override)); + MOCK_METHOD(Status, + AsyncGetAll, + (const MultiItemCallback &callback), + (override)); }; } // namespace gcs @@ -193,21 +243,27 @@ namespace gcs { class MockWorkerInfoAccessor : public WorkerInfoAccessor { public: - MOCK_METHOD(Status, AsyncSubscribeToWorkerFailures, + MOCK_METHOD(Status, + AsyncSubscribeToWorkerFailures, (const ItemCallback &subscribe, const StatusCallback &done), (override)); - MOCK_METHOD(Status, AsyncReportWorkerFailure, + MOCK_METHOD(Status, + AsyncReportWorkerFailure, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGet, + MOCK_METHOD(Status, + AsyncGet, (const WorkerID &worker_id, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetAll, - (const MultiItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncAdd, + MOCK_METHOD(Status, + AsyncGetAll, + (const MultiItemCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncAdd, (const std::shared_ptr &data_ptr, const StatusCallback &callback), (override)); @@ -222,23 +278,33 @@ namespace gcs { class MockPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { public: - MOCK_METHOD(Status, SyncCreatePlacementGroup, - (const PlacementGroupSpecification &placement_group_spec), (override)); - MOCK_METHOD(Status, AsyncGet, + MOCK_METHOD(Status, + SyncCreatePlacementGroup, + (const PlacementGroupSpecification &placement_group_spec), + (override)); + MOCK_METHOD(Status, + AsyncGet, (const PlacementGroupID &placement_group_id, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetByName, - (const std::string &placement_group_name, const std::string &ray_namespace, + MOCK_METHOD(Status, + AsyncGetByName, + (const std::string 
&placement_group_name, + const std::string &ray_namespace, const OptionalItemCallback &callback, int64_t timeout_ms), (override)); - MOCK_METHOD(Status, AsyncGetAll, + MOCK_METHOD(Status, + AsyncGetAll, (const MultiItemCallback &callback), (override)); - MOCK_METHOD(Status, SyncRemovePlacementGroup, - (const PlacementGroupID &placement_group_id), (override)); - MOCK_METHOD(Status, SyncWaitUntilReady, (const PlacementGroupID &placement_group_id), + MOCK_METHOD(Status, + SyncRemovePlacementGroup, + (const PlacementGroupID &placement_group_id), + (override)); + MOCK_METHOD(Status, + SyncWaitUntilReady, + (const PlacementGroupID &placement_group_id), (override)); }; @@ -250,24 +316,37 @@ namespace gcs { class MockInternalKVAccessor : public InternalKVAccessor { public: - MOCK_METHOD(Status, AsyncInternalKVKeys, - (const std::string &ns, const std::string &prefix, + MOCK_METHOD(Status, + AsyncInternalKVKeys, + (const std::string &ns, + const std::string &prefix, const OptionalItemCallback> &callback), (override)); - MOCK_METHOD(Status, AsyncInternalKVGet, - (const std::string &ns, const std::string &key, + MOCK_METHOD(Status, + AsyncInternalKVGet, + (const std::string &ns, + const std::string &key, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncInternalKVPut, - (const std::string &ns, const std::string &key, const std::string &value, - bool overwrite, const OptionalItemCallback &callback), + MOCK_METHOD(Status, + AsyncInternalKVPut, + (const std::string &ns, + const std::string &key, + const std::string &value, + bool overwrite, + const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncInternalKVExists, - (const std::string &ns, const std::string &key, + MOCK_METHOD(Status, + AsyncInternalKVExists, + (const std::string &ns, + const std::string &key, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncInternalKVDel, - (const std::string &ns, const std::string &key, bool del_by_prefix, + MOCK_METHOD(Status, + AsyncInternalKVDel, + (const std::string &ns, + const std::string &key, + bool del_by_prefix, const StatusCallback &callback), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h b/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h index 48e0a949f..ebb516697 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h @@ -27,34 +27,44 @@ namespace gcs { class MockGcsActorManager : public GcsActorManager { public: - MOCK_METHOD(void, HandleRegisterActor, - (const rpc::RegisterActorRequest &request, rpc::RegisterActorReply *reply, + MOCK_METHOD(void, + HandleRegisterActor, + (const rpc::RegisterActorRequest &request, + rpc::RegisterActorReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleCreateActor, - (const rpc::CreateActorRequest &request, rpc::CreateActorReply *reply, + MOCK_METHOD(void, + HandleCreateActor, + (const rpc::CreateActorRequest &request, + rpc::CreateActorReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetActorInfo, - (const rpc::GetActorInfoRequest &request, rpc::GetActorInfoReply *reply, + MOCK_METHOD(void, + HandleGetActorInfo, + (const rpc::GetActorInfoRequest &request, + rpc::GetActorInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetNamedActorInfo, + MOCK_METHOD(void, + HandleGetNamedActorInfo, (const rpc::GetNamedActorInfoRequest &request, rpc::GetNamedActorInfoReply 
*reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleListNamedActors, + MOCK_METHOD(void, + HandleListNamedActors, (const rpc::ListNamedActorsRequest &request, rpc::ListNamedActorsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllActorInfo, + MOCK_METHOD(void, + HandleGetAllActorInfo, (const rpc::GetAllActorInfoRequest &request, rpc::GetAllActorInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleKillActorViaGcs, + MOCK_METHOD(void, + HandleKillActorViaGcs, (const rpc::KillActorViaGcsRequest &request, rpc::KillActorViaGcsReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/mock/ray/gcs/gcs_server/gcs_actor_scheduler.h index 63e61aece..c2e626a33 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_actor_scheduler.h +++ b/src/mock/ray/gcs/gcs_server/gcs_actor_scheduler.h @@ -20,13 +20,17 @@ class MockGcsActorSchedulerInterface : public GcsActorSchedulerInterface { MOCK_METHOD(void, Schedule, (std::shared_ptr actor), (override)); MOCK_METHOD(void, Reschedule, (std::shared_ptr actor), (override)); MOCK_METHOD(std::vector, CancelOnNode, (const NodeID &node_id), (override)); - MOCK_METHOD(void, CancelOnLeasing, + MOCK_METHOD(void, + CancelOnLeasing, (const NodeID &node_id, const ActorID &actor_id, const TaskID &task_id), (override)); - MOCK_METHOD(ActorID, CancelOnWorker, (const NodeID &node_id, const WorkerID &worker_id), + MOCK_METHOD(ActorID, + CancelOnWorker, + (const NodeID &node_id, const WorkerID &worker_id), (override)); MOCK_METHOD( - void, ReleaseUnusedWorkers, + void, + ReleaseUnusedWorkers, ((const std::unordered_map> &node_to_workers)), (override)); }; @@ -42,25 +46,36 @@ class MockGcsActorScheduler : public GcsActorScheduler { MOCK_METHOD(void, Schedule, (std::shared_ptr actor), (override)); MOCK_METHOD(void, Reschedule, (std::shared_ptr actor), (override)); MOCK_METHOD(std::vector, CancelOnNode, (const NodeID &node_id), (override)); - MOCK_METHOD(void, CancelOnLeasing, + MOCK_METHOD(void, + CancelOnLeasing, (const NodeID &node_id, const ActorID &actor_id, const TaskID &task_id), (override)); - MOCK_METHOD(ActorID, CancelOnWorker, (const NodeID &node_id, const WorkerID &worker_id), + MOCK_METHOD(ActorID, + CancelOnWorker, + (const NodeID &node_id, const WorkerID &worker_id), (override)); MOCK_METHOD( - void, ReleaseUnusedWorkers, + void, + ReleaseUnusedWorkers, ((const std::unordered_map> &node_to_workers)), (override)); - MOCK_METHOD(std::shared_ptr, SelectNode, - (std::shared_ptr actor), (override)); - MOCK_METHOD(void, HandleWorkerLeaseReply, - (std::shared_ptr actor, std::shared_ptr node, - const Status &status, const rpc::RequestWorkerLeaseReply &reply), + MOCK_METHOD(std::shared_ptr, + SelectNode, + (std::shared_ptr actor), (override)); - MOCK_METHOD(void, RetryLeasingWorkerFromNode, + MOCK_METHOD(void, + HandleWorkerLeaseReply, + (std::shared_ptr actor, + std::shared_ptr node, + const Status &status, + const rpc::RequestWorkerLeaseReply &reply), + (override)); + MOCK_METHOD(void, + RetryLeasingWorkerFromNode, (std::shared_ptr actor, std::shared_ptr node), (override)); - MOCK_METHOD(void, RetryCreatingActorOnWorker, + MOCK_METHOD(void, + RetryCreatingActorOnWorker, (std::shared_ptr actor, std::shared_ptr worker), (override)); }; @@ -73,11 +88,16 @@ namespace gcs { class MockRayletBasedActorScheduler : public RayletBasedActorScheduler { public: - 
MOCK_METHOD(std::shared_ptr, SelectNode, - (std::shared_ptr actor), (override)); - MOCK_METHOD(void, HandleWorkerLeaseReply, - (std::shared_ptr actor, std::shared_ptr node, - const Status &status, const rpc::RequestWorkerLeaseReply &reply), + MOCK_METHOD(std::shared_ptr, + SelectNode, + (std::shared_ptr actor), + (override)); + MOCK_METHOD(void, + HandleWorkerLeaseReply, + (std::shared_ptr actor, + std::shared_ptr node, + const Status &status, + const rpc::RequestWorkerLeaseReply &reply), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/gcs_heartbeat_manager.h b/src/mock/ray/gcs/gcs_server/gcs_heartbeat_manager.h index 1c0933116..8b95db06c 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_heartbeat_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_heartbeat_manager.h @@ -17,13 +17,16 @@ namespace gcs { class MockGcsHeartbeatManager : public GcsHeartbeatManager { public: - MOCK_METHOD(void, HandleReportHeartbeat, + MOCK_METHOD(void, + HandleReportHeartbeat, (const rpc::ReportHeartbeatRequest &request, rpc::ReportHeartbeatReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleCheckAlive, - (const rpc::CheckAliveRequest &request, rpc::CheckAliveReply *reply, + MOCK_METHOD(void, + HandleCheckAlive, + (const rpc::CheckAliveRequest &request, + rpc::CheckAliveReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/gcs_job_manager.h b/src/mock/ray/gcs/gcs_server/gcs_job_manager.h index e1b006085..919299144 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_job_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_job_manager.h @@ -17,29 +17,40 @@ namespace gcs { class MockGcsJobManager : public GcsJobManager { public: - MOCK_METHOD(void, HandleAddJob, - (const rpc::AddJobRequest &request, rpc::AddJobReply *reply, + MOCK_METHOD(void, + HandleAddJob, + (const rpc::AddJobRequest &request, + rpc::AddJobReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleMarkJobFinished, + MOCK_METHOD(void, + HandleMarkJobFinished, (const rpc::MarkJobFinishedRequest &request, rpc::MarkJobFinishedReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllJobInfo, - (const rpc::GetAllJobInfoRequest &request, rpc::GetAllJobInfoReply *reply, + MOCK_METHOD(void, + HandleGetAllJobInfo, + (const rpc::GetAllJobInfoRequest &request, + rpc::GetAllJobInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReportJobError, - (const rpc::ReportJobErrorRequest &request, rpc::ReportJobErrorReply *reply, + MOCK_METHOD(void, + HandleReportJobError, + (const rpc::ReportJobErrorRequest &request, + rpc::ReportJobErrorReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetNextJobID, - (const rpc::GetNextJobIDRequest &request, rpc::GetNextJobIDReply *reply, + MOCK_METHOD(void, + HandleGetNextJobID, + (const rpc::GetNextJobIDRequest &request, + rpc::GetNextJobIDReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, AddJobFinishedListener, - (std::function)> listener), (override)); + MOCK_METHOD(void, + AddJobFinishedListener, + (std::function)> listener), + (override)); }; } // namespace gcs diff --git a/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h b/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h index 66ec8760d..1eeba85bc 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h @@ 
-19,24 +19,37 @@ class MockInternalKVInterface : public ray::gcs::InternalKVInterface { public: MockInternalKVInterface() {} - MOCK_METHOD(void, Get, - (const std::string &ns, const std::string &key, + MOCK_METHOD(void, + Get, + (const std::string &ns, + const std::string &key, std::function)> callback), (override)); - MOCK_METHOD(void, Put, - (const std::string &ns, const std::string &key, const std::string &value, - bool overwrite, std::function callback), - (override)); - MOCK_METHOD(void, Del, - (const std::string &ns, const std::string &key, bool del_by_prefix, - std::function callback), - (override)); - MOCK_METHOD(void, Exists, - (const std::string &ns, const std::string &key, + MOCK_METHOD(void, + Put, + (const std::string &ns, + const std::string &key, + const std::string &value, + bool overwrite, std::function callback), (override)); - MOCK_METHOD(void, Keys, - (const std::string &ns, const std::string &prefix, + MOCK_METHOD(void, + Del, + (const std::string &ns, + const std::string &key, + bool del_by_prefix, + std::function callback), + (override)); + MOCK_METHOD(void, + Exists, + (const std::string &ns, + const std::string &key, + std::function callback), + (override)); + MOCK_METHOD(void, + Keys, + (const std::string &ns, + const std::string &prefix, std::function)> callback), (override)); MOCK_METHOD(instrumented_io_context &, GetEventLoop, (), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h index 8345633cc..8bae5ace2 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h @@ -18,19 +18,26 @@ namespace gcs { class MockGcsNodeManager : public GcsNodeManager { public: MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr) {} - MOCK_METHOD(void, HandleRegisterNode, - (const rpc::RegisterNodeRequest &request, rpc::RegisterNodeReply *reply, + MOCK_METHOD(void, + HandleRegisterNode, + (const rpc::RegisterNodeRequest &request, + rpc::RegisterNodeReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleDrainNode, - (const rpc::DrainNodeRequest &request, rpc::DrainNodeReply *reply, + MOCK_METHOD(void, + HandleDrainNode, + (const rpc::DrainNodeRequest &request, + rpc::DrainNodeReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllNodeInfo, - (const rpc::GetAllNodeInfoRequest &request, rpc::GetAllNodeInfoReply *reply, + MOCK_METHOD(void, + HandleGetAllNodeInfo, + (const rpc::GetAllNodeInfoRequest &request, + rpc::GetAllNodeInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetInternalConfig, + MOCK_METHOD(void, + HandleGetInternalConfig, (const rpc::GetInternalConfigRequest &request, rpc::GetInternalConfigReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/mock/ray/gcs/gcs_server/gcs_placement_group_manager.h index f8aa0018f..40b7ea2a2 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -27,32 +27,38 @@ namespace gcs { class MockGcsPlacementGroupManager : public GcsPlacementGroupManager { public: - MOCK_METHOD(void, HandleCreatePlacementGroup, + MOCK_METHOD(void, + HandleCreatePlacementGroup, (const rpc::CreatePlacementGroupRequest &request, rpc::CreatePlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - 
MOCK_METHOD(void, HandleRemovePlacementGroup, + MOCK_METHOD(void, + HandleRemovePlacementGroup, (const rpc::RemovePlacementGroupRequest &request, rpc::RemovePlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetPlacementGroup, + MOCK_METHOD(void, + HandleGetPlacementGroup, (const rpc::GetPlacementGroupRequest &request, rpc::GetPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetNamedPlacementGroup, + MOCK_METHOD(void, + HandleGetNamedPlacementGroup, (const rpc::GetNamedPlacementGroupRequest &request, rpc::GetNamedPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllPlacementGroup, + MOCK_METHOD(void, + HandleGetAllPlacementGroup, (const rpc::GetAllPlacementGroupRequest &request, rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleWaitPlacementGroupUntilReady, + MOCK_METHOD(void, + HandleWaitPlacementGroupUntilReady, (const rpc::WaitPlacementGroupUntilReadyRequest &request, rpc::WaitPlacementGroupUntilReadyReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index cf6201355..ca45c9d35 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/mock/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -28,25 +28,34 @@ namespace gcs { class MockGcsPlacementGroupSchedulerInterface : public GcsPlacementGroupSchedulerInterface { public: - MOCK_METHOD(void, ScheduleUnplacedBundles, + MOCK_METHOD(void, + ScheduleUnplacedBundles, (std::shared_ptr placement_group, PGSchedulingFailureCallback failure_callback, PGSchedulingSuccessfulCallback success_callback), (override)); MOCK_METHOD((absl::flat_hash_map>), - GetBundlesOnNode, (const NodeID &node_id), (override)); - MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, - (const PlacementGroupID &placement_group_id), (override)); - MOCK_METHOD(void, MarkScheduleCancelled, (const PlacementGroupID &placement_group_id), + GetBundlesOnNode, + (const NodeID &node_id), + (override)); + MOCK_METHOD(void, + DestroyPlacementGroupBundleResourcesIfExists, + (const PlacementGroupID &placement_group_id), + (override)); + MOCK_METHOD(void, + MarkScheduleCancelled, + (const PlacementGroupID &placement_group_id), (override)); MOCK_METHOD( - void, ReleaseUnusedBundles, + void, + ReleaseUnusedBundles, ((const absl::flat_hash_map> &node_to_bundles)), (override)); - MOCK_METHOD(void, Initialize, + MOCK_METHOD(void, + Initialize, ((const absl::flat_hash_map< - PlacementGroupID, std::vector>> - &group_to_bundles)), + PlacementGroupID, + std::vector>> &group_to_bundles)), (override)); }; @@ -69,7 +78,8 @@ namespace gcs { class MockGcsScheduleStrategy : public GcsScheduleStrategy { public: MOCK_METHOD( - ScheduleResult, Schedule, + ScheduleResult, + Schedule, (const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler), @@ -85,7 +95,8 @@ namespace gcs { class MockGcsPackStrategy : public GcsPackStrategy { public: MOCK_METHOD( - ScheduleResult, Schedule, + ScheduleResult, + Schedule, (const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler), @@ -101,7 +112,8 @@ namespace gcs { class MockGcsSpreadStrategy : public GcsSpreadStrategy { public: MOCK_METHOD( 
- ScheduleResult, Schedule, + ScheduleResult, + Schedule, (const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler), @@ -117,7 +129,8 @@ namespace gcs { class MockGcsStrictPackStrategy : public GcsStrictPackStrategy { public: MOCK_METHOD( - ScheduleResult, Schedule, + ScheduleResult, + Schedule, (const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler), @@ -133,7 +146,8 @@ namespace gcs { class MockGcsStrictSpreadStrategy : public GcsStrictSpreadStrategy { public: MOCK_METHOD( - ScheduleResult, Schedule, + ScheduleResult, + Schedule, (const std::vector> &bundles, const std::unique_ptr &context, GcsResourceScheduler &gcs_resource_scheduler), @@ -168,19 +182,27 @@ namespace gcs { class MockGcsPlacementGroupScheduler : public GcsPlacementGroupScheduler { public: - MOCK_METHOD(void, ScheduleUnplacedBundles, + MOCK_METHOD(void, + ScheduleUnplacedBundles, (std::shared_ptr placement_group, PGSchedulingFailureCallback failure_handler, PGSchedulingSuccessfulCallback success_handler), (override)); - MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, - (const PlacementGroupID &placement_group_id), (override)); - MOCK_METHOD(void, MarkScheduleCancelled, (const PlacementGroupID &placement_group_id), + MOCK_METHOD(void, + DestroyPlacementGroupBundleResourcesIfExists, + (const PlacementGroupID &placement_group_id), + (override)); + MOCK_METHOD(void, + MarkScheduleCancelled, + (const PlacementGroupID &placement_group_id), (override)); MOCK_METHOD((absl::flat_hash_map>), - GetBundlesOnNode, (const NodeID &node_id), (override)); + GetBundlesOnNode, + (const NodeID &node_id), + (override)); MOCK_METHOD( - void, ReleaseUnusedBundles, + void, + ReleaseUnusedBundles, ((const absl::flat_hash_map> &node_to_bundles)), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h index 6ffe273f7..44addab46 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h @@ -18,21 +18,26 @@ namespace gcs { class MockGcsResourceManager : public GcsResourceManager { public: using GcsResourceManager::GcsResourceManager; - MOCK_METHOD(void, HandleGetResources, - (const rpc::GetResourcesRequest &request, rpc::GetResourcesReply *reply, + MOCK_METHOD(void, + HandleGetResources, + (const rpc::GetResourcesRequest &request, + rpc::GetResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllAvailableResources, + MOCK_METHOD(void, + HandleGetAllAvailableResources, (const rpc::GetAllAvailableResourcesRequest &request, rpc::GetAllAvailableResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReportResourceUsage, + MOCK_METHOD(void, + HandleReportResourceUsage, (const rpc::ReportResourceUsageRequest &request, rpc::ReportResourceUsageReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllResourceUsage, + MOCK_METHOD(void, + HandleGetAllResourceUsage, (const rpc::GetAllResourceUsageRequest &request, rpc::GetAllResourceUsageReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/gcs_server/gcs_resource_scheduler.h b/src/mock/ray/gcs/gcs_server/gcs_resource_scheduler.h index 73a6cc0a8..613ce1eb1 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_resource_scheduler.h +++ 
b/src/mock/ray/gcs/gcs_server/gcs_resource_scheduler.h @@ -17,7 +17,8 @@ namespace gcs { class MockNodeScorer : public NodeScorer { public: - MOCK_METHOD(double, Score, + MOCK_METHOD(double, + Score, (const ResourceSet &required_resources, const SchedulingResources &node_resources), (override)); @@ -31,7 +32,8 @@ namespace gcs { class MockLeastResourceScorer : public LeastResourceScorer { public: - MOCK_METHOD(double, Score, + MOCK_METHOD(double, + Score, (const ResourceSet &required_resources, const SchedulingResources &node_resources), (override)); diff --git a/src/mock/ray/gcs/gcs_server/gcs_table_storage.h b/src/mock/ray/gcs/gcs_server/gcs_table_storage.h index a61133f41..640b49136 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/mock/ray/gcs/gcs_server/gcs_table_storage.h @@ -18,13 +18,18 @@ namespace gcs { template class MockGcsTable : public GcsTable { public: - MOCK_METHOD(Status, Put, + MOCK_METHOD(Status, + Put, (const Key &key, const Data &value, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, Delete, (const Key &key, const StatusCallback &callback), + MOCK_METHOD(Status, + Delete, + (const Key &key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + BatchDelete, + (const std::vector &keys, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, BatchDelete, - (const std::vector &keys, const StatusCallback &callback), (override)); }; } // namespace gcs @@ -36,13 +41,18 @@ namespace gcs { template class MockGcsTableWithJobId : public GcsTableWithJobId { public: - MOCK_METHOD(Status, Put, + MOCK_METHOD(Status, + Put, (const Key &key, const Data &value, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, Delete, (const Key &key, const StatusCallback &callback), + MOCK_METHOD(Status, + Delete, + (const Key &key, const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + BatchDelete, + (const std::vector &keys, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, BatchDelete, - (const std::vector &keys, const StatusCallback &callback), (override)); MOCK_METHOD(JobID, GetJobIdFromKey, (const Key &key), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/gcs_worker_manager.h b/src/mock/ray/gcs/gcs_server/gcs_worker_manager.h index 9755232a8..43c27d996 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_worker_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_worker_manager.h @@ -17,22 +17,28 @@ namespace gcs { class MockGcsWorkerManager : public GcsWorkerManager { public: - MOCK_METHOD(void, HandleReportWorkerFailure, + MOCK_METHOD(void, + HandleReportWorkerFailure, (const rpc::ReportWorkerFailureRequest &request, rpc::ReportWorkerFailureReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetWorkerInfo, - (const rpc::GetWorkerInfoRequest &request, rpc::GetWorkerInfoReply *reply, + MOCK_METHOD(void, + HandleGetWorkerInfo, + (const rpc::GetWorkerInfoRequest &request, + rpc::GetWorkerInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllWorkerInfo, + MOCK_METHOD(void, + HandleGetAllWorkerInfo, (const rpc::GetAllWorkerInfoRequest &request, rpc::GetAllWorkerInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleAddWorkerInfo, - (const rpc::AddWorkerInfoRequest &request, rpc::AddWorkerInfoReply *reply, + MOCK_METHOD(void, + HandleAddWorkerInfo, + (const rpc::AddWorkerInfoRequest &request, + rpc::AddWorkerInfoReply *reply, 
rpc::SendReplyCallback send_reply_callback), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/stats_handler_impl.h b/src/mock/ray/gcs/gcs_server/stats_handler_impl.h index e404f41b1..e7242357b 100644 --- a/src/mock/ray/gcs/gcs_server/stats_handler_impl.h +++ b/src/mock/ray/gcs/gcs_server/stats_handler_impl.h @@ -17,11 +17,14 @@ namespace rpc { class MockDefaultStatsHandler : public DefaultStatsHandler { public: - MOCK_METHOD(void, HandleAddProfileData, - (const AddProfileDataRequest &request, AddProfileDataReply *reply, + MOCK_METHOD(void, + HandleAddProfileData, + (const AddProfileDataRequest &request, + AddProfileDataReply *reply, SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetAllProfileInfo, + MOCK_METHOD(void, + HandleGetAllProfileInfo, (const rpc::GetAllProfileInfoRequest &request, rpc::GetAllProfileInfoReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/gcs/pubsub/gcs_pub_sub.h b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h index 21e500da0..14252da56 100644 --- a/src/mock/ray/gcs/pubsub/gcs_pub_sub.h +++ b/src/mock/ray/gcs/pubsub/gcs_pub_sub.h @@ -17,8 +17,11 @@ namespace gcs { class MockGcsPubSub : public GcsPubSub { public: - MOCK_METHOD(Status, Publish, - (const std::string &channel, const std::string &id, const std::string &data, + MOCK_METHOD(Status, + Publish, + (const std::string &channel, + const std::string &id, + const std::string &data, const StatusCallback &done), (override)); }; diff --git a/src/mock/ray/gcs/store_client/in_memory_store_client.h b/src/mock/ray/gcs/store_client/in_memory_store_client.h index 08af16a07..78201a61b 100644 --- a/src/mock/ray/gcs/store_client/in_memory_store_client.h +++ b/src/mock/ray/gcs/store_client/in_memory_store_client.h @@ -17,46 +17,68 @@ namespace gcs { class MockInMemoryStoreClient : public InMemoryStoreClient { public: - MOCK_METHOD(Status, AsyncPut, - (const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncPutWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, + MOCK_METHOD(Status, + AsyncPut, + (const std::string &table_name, + const std::string &key, + const std::string &data, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGet, - (const std::string &table_name, const std::string &key, + MOCK_METHOD(Status, + AsyncPutWithIndex, + (const std::string &table_name, + const std::string &key, + const std::string &index_key, + const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncGet, + (const std::string &table_name, + const std::string &key, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetByIndex, - (const std::string &table_name, const std::string &index_key, + MOCK_METHOD(Status, + AsyncGetByIndex, + (const std::string &table_name, + const std::string &index_key, (const MapCallback &callback)), (override)); - MOCK_METHOD(Status, AsyncGetAll, + MOCK_METHOD(Status, + AsyncGetAll, (const std::string &table_name, (const MapCallback &callback)), (override)); - MOCK_METHOD(Status, AsyncDelete, - (const std::string &table_name, const std::string &key, + MOCK_METHOD(Status, + AsyncDelete, + (const std::string &table_name, + const std::string &key, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDeleteWithIndex, - (const std::string &table_name, const std::string &key, 
- const std::string &index_key, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDelete, - (const std::string &table_name, const std::vector &keys, + MOCK_METHOD(Status, + AsyncDeleteWithIndex, + (const std::string &table_name, + const std::string &key, + const std::string &index_key, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, - (const std::string &table_name, const std::vector &keys, + MOCK_METHOD(Status, + AsyncBatchDelete, + (const std::string &table_name, + const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncBatchDeleteWithIndex, + (const std::string &table_name, + const std::vector &keys, const std::vector &index_keys, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDeleteByIndex, - (const std::string &table_name, const std::string &index_key, + MOCK_METHOD(Status, + AsyncDeleteByIndex, + (const std::string &table_name, + const std::string &index_key, const StatusCallback &callback), (override)); MOCK_METHOD(int, GetNextJobID, (), (override)); diff --git a/src/mock/ray/gcs/store_client/redis_store_client.h b/src/mock/ray/gcs/store_client/redis_store_client.h index 153a69755..5401625e8 100644 --- a/src/mock/ray/gcs/store_client/redis_store_client.h +++ b/src/mock/ray/gcs/store_client/redis_store_client.h @@ -18,46 +18,68 @@ namespace gcs { class MockRedisStoreClient : public RedisStoreClient { public: MockRedisStoreClient() : RedisStoreClient(nullptr) {} - MOCK_METHOD(Status, AsyncPut, - (const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncPutWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, + MOCK_METHOD(Status, + AsyncPut, + (const std::string &table_name, + const std::string &key, + const std::string &data, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGet, - (const std::string &table_name, const std::string &key, + MOCK_METHOD(Status, + AsyncPutWithIndex, + (const std::string &table_name, + const std::string &key, + const std::string &index_key, + const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncGet, + (const std::string &table_name, + const std::string &key, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetByIndex, - (const std::string &table_name, const std::string &index_key, + MOCK_METHOD(Status, + AsyncGetByIndex, + (const std::string &table_name, + const std::string &index_key, (const MapCallback &callback)), (override)); - MOCK_METHOD(Status, AsyncGetAll, + MOCK_METHOD(Status, + AsyncGetAll, (const std::string &table_name, (const MapCallback &callback)), (override)); - MOCK_METHOD(Status, AsyncDelete, - (const std::string &table_name, const std::string &key, + MOCK_METHOD(Status, + AsyncDelete, + (const std::string &table_name, + const std::string &key, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDeleteWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDelete, - (const std::string &table_name, const std::vector &keys, + MOCK_METHOD(Status, + AsyncDeleteWithIndex, + (const std::string &table_name, + const std::string &key, + const std::string &index_key, const 
StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, - (const std::string &table_name, const std::vector &keys, + MOCK_METHOD(Status, + AsyncBatchDelete, + (const std::string &table_name, + const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncBatchDeleteWithIndex, + (const std::string &table_name, + const std::vector &keys, const std::vector &index_keys, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDeleteByIndex, - (const std::string &table_name, const std::string &index_key, + MOCK_METHOD(Status, + AsyncDeleteByIndex, + (const std::string &table_name, + const std::string &index_key, const StatusCallback &callback), (override)); MOCK_METHOD(int, GetNextJobID, (), (override)); diff --git a/src/mock/ray/gcs/store_client/store_client.h b/src/mock/ray/gcs/store_client/store_client.h index 6f4e3b538..6a92516f1 100644 --- a/src/mock/ray/gcs/store_client/store_client.h +++ b/src/mock/ray/gcs/store_client/store_client.h @@ -17,46 +17,68 @@ namespace gcs { class MockStoreClient : public StoreClient { public: - MOCK_METHOD(Status, AsyncPut, - (const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncPutWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, + MOCK_METHOD(Status, + AsyncPut, + (const std::string &table_name, + const std::string &key, + const std::string &data, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGet, - (const std::string &table_name, const std::string &key, + MOCK_METHOD(Status, + AsyncPutWithIndex, + (const std::string &table_name, + const std::string &key, + const std::string &index_key, + const std::string &data, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncGet, + (const std::string &table_name, + const std::string &key, const OptionalItemCallback &callback), (override)); - MOCK_METHOD(Status, AsyncGetByIndex, - (const std::string &table_name, const std::string &index_key, + MOCK_METHOD(Status, + AsyncGetByIndex, + (const std::string &table_name, + const std::string &index_key, (const MapCallback &callback)), (override)); - MOCK_METHOD(Status, AsyncGetAll, + MOCK_METHOD(Status, + AsyncGetAll, (const std::string &table_name, (const MapCallback &callback)), (override)); - MOCK_METHOD(Status, AsyncDelete, - (const std::string &table_name, const std::string &key, + MOCK_METHOD(Status, + AsyncDelete, + (const std::string &table_name, + const std::string &key, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDeleteWithIndex, - (const std::string &table_name, const std::string &key, - const std::string &index_key, const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, AsyncBatchDelete, - (const std::string &table_name, const std::vector &keys, + MOCK_METHOD(Status, + AsyncDeleteWithIndex, + (const std::string &table_name, + const std::string &key, + const std::string &index_key, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, - (const std::string &table_name, const std::vector &keys, + MOCK_METHOD(Status, + AsyncBatchDelete, + (const std::string &table_name, + const std::vector &keys, + const StatusCallback &callback), + (override)); + MOCK_METHOD(Status, + AsyncBatchDeleteWithIndex, + (const std::string &table_name, + const std::vector &keys, const 
std::vector &index_keys, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, AsyncDeleteByIndex, - (const std::string &table_name, const std::string &index_key, + MOCK_METHOD(Status, + AsyncDeleteByIndex, + (const std::string &table_name, + const std::string &index_key, const StatusCallback &callback), (override)); MOCK_METHOD(int, GetNextJobID, (), (override)); diff --git a/src/mock/ray/pubsub/publisher.h b/src/mock/ray/pubsub/publisher.h index e8b9fca5c..e445ede07 100644 --- a/src/mock/ray/pubsub/publisher.h +++ b/src/mock/ray/pubsub/publisher.h @@ -54,19 +54,26 @@ namespace pubsub { class MockPublisherInterface : public PublisherInterface { public: - MOCK_METHOD(bool, RegisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + MOCK_METHOD(bool, + RegisterSubscription, + (const rpc::ChannelType channel_type, + const SubscriberID &subscriber_id, const std::optional &key_id_binary), (override)); - MOCK_METHOD(void, Publish, - (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, + MOCK_METHOD(void, + Publish, + (const rpc::ChannelType channel_type, + const rpc::PubMessage &pub_message, const std::string &key_id_binary), (override)); - MOCK_METHOD(void, PublishFailure, + MOCK_METHOD(void, + PublishFailure, (const rpc::ChannelType channel_type, const std::string &key_id_binary), (override)); - MOCK_METHOD(bool, UnregisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + MOCK_METHOD(bool, + UnregisterSubscription, + (const rpc::ChannelType channel_type, + const SubscriberID &subscriber_id, const std::optional &key_id_binary), (override)); }; @@ -79,19 +86,26 @@ namespace pubsub { class MockPublisher : public Publisher { public: - MOCK_METHOD(bool, RegisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + MOCK_METHOD(bool, + RegisterSubscription, + (const rpc::ChannelType channel_type, + const SubscriberID &subscriber_id, const std::optional &key_id_binary), (override)); - MOCK_METHOD(void, Publish, - (const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, + MOCK_METHOD(void, + Publish, + (const rpc::ChannelType channel_type, + const rpc::PubMessage &pub_message, const std::string &key_id_binary), (override)); - MOCK_METHOD(void, PublishFailure, + MOCK_METHOD(void, + PublishFailure, (const rpc::ChannelType channel_type, const std::string &key_id_binary), (override)); - MOCK_METHOD(bool, UnregisterSubscription, - (const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, + MOCK_METHOD(bool, + UnregisterSubscription, + (const rpc::ChannelType channel_type, + const SubscriberID &subscriber_id, const std::optional &key_id_binary), (override)); }; diff --git a/src/mock/ray/pubsub/subscriber.h b/src/mock/ray/pubsub/subscriber.h index 47271c41f..98d63828a 100644 --- a/src/mock/ray/pubsub/subscriber.h +++ b/src/mock/ray/pubsub/subscriber.h @@ -17,25 +17,33 @@ namespace pubsub { class MockSubscriberInterface : public SubscriberInterface { public: - MOCK_METHOD(bool, Subscribe, + MOCK_METHOD(bool, + Subscribe, (std::unique_ptr sub_message, - const rpc::ChannelType channel_type, const rpc::Address &publisher_address, - const std::string &key_id, SubscribeDoneCallback subscribe_done_callback, - SubscriptionItemCallback subscription_callback, - SubscriptionFailureCallback subscription_failure_callback), - (override)); - MOCK_METHOD(bool, SubscribeChannel, - (std::unique_ptr sub_message, - const rpc::ChannelType 
channel_type, const rpc::Address &publisher_address, + const rpc::ChannelType channel_type, + const rpc::Address &publisher_address, + const std::string &key_id, SubscribeDoneCallback subscribe_done_callback, SubscriptionItemCallback subscription_callback, SubscriptionFailureCallback subscription_failure_callback), (override)); - MOCK_METHOD(bool, Unsubscribe, - (const rpc::ChannelType channel_type, const rpc::Address &publisher_address, + MOCK_METHOD(bool, + SubscribeChannel, + (std::unique_ptr sub_message, + const rpc::ChannelType channel_type, + const rpc::Address &publisher_address, + SubscribeDoneCallback subscribe_done_callback, + SubscriptionItemCallback subscription_callback, + SubscriptionFailureCallback subscription_failure_callback), + (override)); + MOCK_METHOD(bool, + Unsubscribe, + (const rpc::ChannelType channel_type, + const rpc::Address &publisher_address, const std::string &key_id), (override)); - MOCK_METHOD(bool, UnsubscribeChannel, + MOCK_METHOD(bool, + UnsubscribeChannel, (const rpc::ChannelType channel_type, const rpc::Address &publisher_address), (override)); @@ -50,11 +58,13 @@ namespace pubsub { class MockSubscriberClientInterface : public SubscriberClientInterface { public: - MOCK_METHOD(void, PubsubLongPolling, + MOCK_METHOD(void, + PubsubLongPolling, (const rpc::PubsubLongPollingRequest &request, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, PubsubCommandBatch, + MOCK_METHOD(void, + PubsubCommandBatch, (const rpc::PubsubCommandBatchRequest &request, const rpc::ClientCallback &callback), (override)); diff --git a/src/mock/ray/raylet/agent_manager.h b/src/mock/ray/raylet/agent_manager.h index 5dfebd5a9..9bcebda81 100644 --- a/src/mock/ray/raylet/agent_manager.h +++ b/src/mock/ray/raylet/agent_manager.h @@ -17,15 +17,20 @@ namespace raylet { class MockAgentManager : public AgentManager { public: - MOCK_METHOD(void, HandleRegisterAgent, - (const rpc::RegisterAgentRequest &request, rpc::RegisterAgentReply *reply, + MOCK_METHOD(void, + HandleRegisterAgent, + (const rpc::RegisterAgentRequest &request, + rpc::RegisterAgentReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, CreateRuntimeEnv, - (const JobID &job_id, const std::string &serialized_runtime_env, + MOCK_METHOD(void, + CreateRuntimeEnv, + (const JobID &job_id, + const std::string &serialized_runtime_env, CreateRuntimeEnvCallback callback), (override)); - MOCK_METHOD(void, DeleteRuntimeEnv, + MOCK_METHOD(void, + DeleteRuntimeEnv, (const std::string &serialized_runtime_env, DeleteRuntimeEnvCallback callback), (override)); @@ -39,8 +44,10 @@ namespace raylet { class MockDefaultAgentManagerServiceHandler : public DefaultAgentManagerServiceHandler { public: - MOCK_METHOD(void, HandleRegisterAgent, - (const rpc::RegisterAgentRequest &request, rpc::RegisterAgentReply *reply, + MOCK_METHOD(void, + HandleRegisterAgent, + (const rpc::RegisterAgentRequest &request, + rpc::RegisterAgentReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); }; diff --git a/src/mock/ray/raylet/dependency_manager.h b/src/mock/ray/raylet/dependency_manager.h index 9b1e32b50..4523dc0ac 100644 --- a/src/mock/ray/raylet/dependency_manager.h +++ b/src/mock/ray/raylet/dependency_manager.h @@ -17,7 +17,8 @@ namespace raylet { class MockTaskDependencyManagerInterface : public TaskDependencyManagerInterface { public: - MOCK_METHOD(bool, RequestTaskDependencies, + MOCK_METHOD(bool, + RequestTaskDependencies, (const TaskID &task_id, const std::vector &required_objects), 
(override)); diff --git a/src/mock/ray/raylet/node_manager.h b/src/mock/ray/raylet/node_manager.h index 1ce3563ba..c4ae5b33c 100644 --- a/src/mock/ray/raylet/node_manager.h +++ b/src/mock/ray/raylet/node_manager.h @@ -37,88 +37,110 @@ namespace raylet { class MockNodeManager : public NodeManager { public: - MOCK_METHOD(void, HandleUpdateResourceUsage, + MOCK_METHOD(void, + HandleUpdateResourceUsage, (const rpc::UpdateResourceUsageRequest &request, rpc::UpdateResourceUsageReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRequestResourceReport, + MOCK_METHOD(void, + HandleRequestResourceReport, (const rpc::RequestResourceReportRequest &request, rpc::RequestResourceReportReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandlePrepareBundleResources, + MOCK_METHOD(void, + HandlePrepareBundleResources, (const rpc::PrepareBundleResourcesRequest &request, rpc::PrepareBundleResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleCommitBundleResources, + MOCK_METHOD(void, + HandleCommitBundleResources, (const rpc::CommitBundleResourcesRequest &request, rpc::CommitBundleResourcesReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleCancelResourceReserve, + MOCK_METHOD(void, + HandleCancelResourceReserve, (const rpc::CancelResourceReserveRequest &request, rpc::CancelResourceReserveReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRequestWorkerLease, + MOCK_METHOD(void, + HandleRequestWorkerLease, (const rpc::RequestWorkerLeaseRequest &request, rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReportWorkerBacklog, + MOCK_METHOD(void, + HandleReportWorkerBacklog, (const rpc::ReportWorkerBacklogRequest &request, rpc::ReportWorkerBacklogReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReturnWorker, - (const rpc::ReturnWorkerRequest &request, rpc::ReturnWorkerReply *reply, + MOCK_METHOD(void, + HandleReturnWorker, + (const rpc::ReturnWorkerRequest &request, + rpc::ReturnWorkerReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReleaseUnusedWorkers, + MOCK_METHOD(void, + HandleReleaseUnusedWorkers, (const rpc::ReleaseUnusedWorkersRequest &request, rpc::ReleaseUnusedWorkersReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleCancelWorkerLease, + MOCK_METHOD(void, + HandleCancelWorkerLease, (const rpc::CancelWorkerLeaseRequest &request, rpc::CancelWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandlePinObjectIDs, - (const rpc::PinObjectIDsRequest &request, rpc::PinObjectIDsReply *reply, + MOCK_METHOD(void, + HandlePinObjectIDs, + (const rpc::PinObjectIDsRequest &request, + rpc::PinObjectIDsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetNodeStats, - (const rpc::GetNodeStatsRequest &request, rpc::GetNodeStatsReply *reply, + MOCK_METHOD(void, + HandleGetNodeStats, + (const rpc::GetNodeStatsRequest &request, + rpc::GetNodeStatsReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGlobalGC, - (const rpc::GlobalGCRequest &request, rpc::GlobalGCReply *reply, + MOCK_METHOD(void, + HandleGlobalGC, + (const rpc::GlobalGCRequest 
&request, + rpc::GlobalGCReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleFormatGlobalMemoryInfo, + MOCK_METHOD(void, + HandleFormatGlobalMemoryInfo, (const rpc::FormatGlobalMemoryInfoRequest &request, rpc::FormatGlobalMemoryInfoReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleRequestObjectSpillage, + MOCK_METHOD(void, + HandleRequestObjectSpillage, (const rpc::RequestObjectSpillageRequest &request, rpc::RequestObjectSpillageReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleReleaseUnusedBundles, + MOCK_METHOD(void, + HandleReleaseUnusedBundles, (const rpc::ReleaseUnusedBundlesRequest &request, rpc::ReleaseUnusedBundlesReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetSystemConfig, + MOCK_METHOD(void, + HandleGetSystemConfig, (const rpc::GetSystemConfigRequest &request, rpc::GetSystemConfigReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(void, HandleGetGcsServerAddress, + MOCK_METHOD(void, + HandleGetGcsServerAddress, (const rpc::GetGcsServerAddressRequest &request, rpc::GetGcsServerAddressReply *reply, rpc::SendReplyCallback send_reply_callback), diff --git a/src/mock/ray/raylet/scheduling/cluster_task_manager.h b/src/mock/ray/raylet/scheduling/cluster_task_manager.h index 21879b96f..d251ef0a8 100644 --- a/src/mock/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/mock/ray/raylet/scheduling/cluster_task_manager.h @@ -27,31 +27,49 @@ namespace raylet { class MockClusterTaskManager : public ClusterTaskManager { public: - MOCK_METHOD(void, QueueAndScheduleTask, - (const RayTask &task, rpc::RequestWorkerLeaseReply *reply, + MOCK_METHOD(void, + QueueAndScheduleTask, + (const RayTask &task, + rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); MOCK_METHOD(void, TasksUnblocked, (const std::vector &ready_ids), (override)); - MOCK_METHOD(void, TaskFinished, - (std::shared_ptr worker, RayTask *task), (override)); - MOCK_METHOD(bool, CancelTask, (const TaskID &task_id, bool runtime_env_setup_failed), + MOCK_METHOD(void, + TaskFinished, + (std::shared_ptr worker, RayTask *task), (override)); - MOCK_METHOD(void, FillPendingActorInfo, (rpc::GetNodeStatsReply * reply), + MOCK_METHOD(bool, + CancelTask, + (const TaskID &task_id, bool runtime_env_setup_failed), + (override)); + MOCK_METHOD(void, + FillPendingActorInfo, + (rpc::GetNodeStatsReply * reply), (const, override)); - MOCK_METHOD(void, FillResourceUsage, + MOCK_METHOD(void, + FillResourceUsage, (rpc::ResourcesData & data, const std::shared_ptr &last_reported_resources), (override)); - MOCK_METHOD(bool, AnyPendingTasksForResourceAcquisition, - (RayTask * exemplar, bool *any_pending, int *num_pending_actor_creation, + MOCK_METHOD(bool, + AnyPendingTasksForResourceAcquisition, + (RayTask * exemplar, + bool *any_pending, + int *num_pending_actor_creation, int *num_pending_tasks), (const, override)); - MOCK_METHOD(void, ReleaseWorkerResources, (std::shared_ptr worker), + MOCK_METHOD(void, + ReleaseWorkerResources, + (std::shared_ptr worker), + (override)); + MOCK_METHOD(bool, + ReleaseCpuResourcesFromUnblockedWorker, + (std::shared_ptr worker), + (override)); + MOCK_METHOD(bool, + ReturnCpuResourcesToBlockedWorker, + (std::shared_ptr worker), (override)); - MOCK_METHOD(bool, ReleaseCpuResourcesFromUnblockedWorker, - (std::shared_ptr worker), (override)); - MOCK_METHOD(bool, 
ReturnCpuResourcesToBlockedWorker, - (std::shared_ptr worker), (override)); MOCK_METHOD(void, ScheduleAndDispatchTasks, (), (override)); MOCK_METHOD(void, RecordMetrics, (), (override)); MOCK_METHOD(std::string, DebugStr, (), (const, override)); diff --git a/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h index c32ed4f06..e75e0534b 100644 --- a/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h +++ b/src/mock/ray/raylet/scheduling/cluster_task_manager_interface.h @@ -17,33 +17,50 @@ namespace raylet { class MockClusterTaskManagerInterface : public ClusterTaskManagerInterface { public: - MOCK_METHOD(void, ReleaseWorkerResources, (std::shared_ptr worker), + MOCK_METHOD(void, + ReleaseWorkerResources, + (std::shared_ptr worker), + (override)); + MOCK_METHOD(bool, + ReleaseCpuResourcesFromUnblockedWorker, + (std::shared_ptr worker), + (override)); + MOCK_METHOD(bool, + ReturnCpuResourcesToBlockedWorker, + (std::shared_ptr worker), (override)); - MOCK_METHOD(bool, ReleaseCpuResourcesFromUnblockedWorker, - (std::shared_ptr worker), (override)); - MOCK_METHOD(bool, ReturnCpuResourcesToBlockedWorker, - (std::shared_ptr worker), (override)); MOCK_METHOD(void, ScheduleAndDispatchTasks, (), (override)); MOCK_METHOD(void, TasksUnblocked, (const std::vector &ready_ids), (override)); - MOCK_METHOD(void, FillResourceUsage, + MOCK_METHOD(void, + FillResourceUsage, (rpc::ResourcesData & data, const std::shared_ptr &last_reported_resources), (override)); - MOCK_METHOD(void, FillPendingActorInfo, (rpc::GetNodeStatsReply * reply), + MOCK_METHOD(void, + FillPendingActorInfo, + (rpc::GetNodeStatsReply * reply), (const, override)); - MOCK_METHOD(void, TaskFinished, - (std::shared_ptr worker, RayTask *task), (override)); - MOCK_METHOD(bool, CancelTask, + MOCK_METHOD(void, + TaskFinished, + (std::shared_ptr worker, RayTask *task), + (override)); + MOCK_METHOD(bool, + CancelTask, (const TaskID &task_id, rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, const std::string &scheduling_failure_message), (override)); - MOCK_METHOD(void, QueueAndScheduleTask, - (const RayTask &task, rpc::RequestWorkerLeaseReply *reply, + MOCK_METHOD(void, + QueueAndScheduleTask, + (const RayTask &task, + rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback), (override)); - MOCK_METHOD(bool, AnyPendingTasksForResourceAcquisition, - (RayTask * exemplar, bool *any_pending, int *num_pending_actor_creation, + MOCK_METHOD(bool, + AnyPendingTasksForResourceAcquisition, + (RayTask * exemplar, + bool *any_pending, + int *num_pending_actor_creation, int *num_pending_tasks), (const, override)); MOCK_METHOD(std::string, DebugStr, (), (const, override)); diff --git a/src/mock/ray/raylet/worker.h b/src/mock/ray/raylet/worker.h index bf1120e70..a296183d4 100644 --- a/src/mock/ray/raylet/worker.h +++ b/src/mock/ray/raylet/worker.h @@ -29,7 +29,9 @@ class MockWorkerInterface : public WorkerInterface { MOCK_METHOD(Language, GetLanguage, (), (const, override)); MOCK_METHOD(const std::string, IpAddress, (), (const, override)); MOCK_METHOD(void, Connect, (int port), (override)); - MOCK_METHOD(void, Connect, (std::shared_ptr rpc_client), + MOCK_METHOD(void, + Connect, + (std::shared_ptr rpc_client), (override)); MOCK_METHOD(int, Port, (), (const, override)); MOCK_METHOD(int, AssignedPort, (), (const, override)); @@ -38,7 +40,9 @@ class MockWorkerInterface : public WorkerInterface { MOCK_METHOD(const TaskID &, 
GetAssignedTaskId, (), (const, override)); MOCK_METHOD(bool, AddBlockedTaskId, (const TaskID &task_id), (override)); MOCK_METHOD(bool, RemoveBlockedTaskId, (const TaskID &task_id), (override)); - MOCK_METHOD(const std::unordered_set &, GetBlockedTaskIds, (), + MOCK_METHOD(const std::unordered_set &, + GetBlockedTaskIds, + (), (const, override)); MOCK_METHOD(const JobID &, GetAssignedJobId, (), (const, override)); MOCK_METHOD(int, GetRuntimeEnvHash, (), (const, override)); @@ -52,16 +56,22 @@ class MockWorkerInterface : public WorkerInterface { MOCK_METHOD(void, DirectActorCallArgWaitComplete, (int64_t tag), (override)); MOCK_METHOD(const BundleID &, GetBundleId, (), (const, override)); MOCK_METHOD(void, SetBundleId, (const BundleID &bundle_id), (override)); - MOCK_METHOD(void, SetAllocatedInstances, + MOCK_METHOD(void, + SetAllocatedInstances, (const std::shared_ptr &allocated_instances), (override)); - MOCK_METHOD(std::shared_ptr, GetAllocatedInstances, (), + MOCK_METHOD(std::shared_ptr, + GetAllocatedInstances, + (), (override)); MOCK_METHOD(void, ClearAllocatedInstances, (), (override)); - MOCK_METHOD(void, SetLifetimeAllocatedInstances, + MOCK_METHOD(void, + SetLifetimeAllocatedInstances, (const std::shared_ptr &allocated_instances), (override)); - MOCK_METHOD(std::shared_ptr, GetLifetimeAllocatedInstances, (), + MOCK_METHOD(std::shared_ptr, + GetLifetimeAllocatedInstances, + (), (override)); MOCK_METHOD(void, ClearLifetimeAllocatedInstances, (), (override)); MOCK_METHOD(RayTask &, GetAssignedTask, (), (override)); diff --git a/src/mock/ray/raylet/worker_pool.h b/src/mock/ray/raylet/worker_pool.h index adc21d59b..b3ded72ac 100644 --- a/src/mock/ray/raylet/worker_pool.h +++ b/src/mock/ray/raylet/worker_pool.h @@ -17,14 +17,20 @@ namespace raylet { class MockWorkerPoolInterface : public WorkerPoolInterface { public: - MOCK_METHOD(void, PopWorker, - (const TaskSpecification &task_spec, const PopWorkerCallback &callback, + MOCK_METHOD(void, + PopWorker, + (const TaskSpecification &task_spec, + const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json), (override)); - MOCK_METHOD(void, PushWorker, (const std::shared_ptr &worker), + MOCK_METHOD(void, + PushWorker, + (const std::shared_ptr &worker), (override)); MOCK_METHOD(const std::vector>, - GetAllRegisteredWorkers, (bool filter_dead_workers), (override)); + GetAllRegisteredWorkers, + (bool filter_dead_workers), + (override)); }; } // namespace raylet @@ -35,24 +41,36 @@ namespace raylet { class MockIOWorkerPoolInterface : public IOWorkerPoolInterface { public: - MOCK_METHOD(void, PushSpillWorker, (const std::shared_ptr &worker), + MOCK_METHOD(void, + PushSpillWorker, + (const std::shared_ptr &worker), (override)); - MOCK_METHOD(void, PopSpillWorker, + MOCK_METHOD(void, + PopSpillWorker, (std::function)> callback), (override)); - MOCK_METHOD(void, PushRestoreWorker, (const std::shared_ptr &worker), + MOCK_METHOD(void, + PushRestoreWorker, + (const std::shared_ptr &worker), (override)); - MOCK_METHOD(void, PopRestoreWorker, + MOCK_METHOD(void, + PopRestoreWorker, (std::function)> callback), (override)); - MOCK_METHOD(void, PushDeleteWorker, (const std::shared_ptr &worker), + MOCK_METHOD(void, + PushDeleteWorker, + (const std::shared_ptr &worker), (override)); - MOCK_METHOD(void, PopDeleteWorker, + MOCK_METHOD(void, + PopDeleteWorker, (std::function)> callback), (override)); - MOCK_METHOD(void, PushUtilWorker, (const std::shared_ptr &worker), + MOCK_METHOD(void, + PushUtilWorker, + (const std::shared_ptr 
&worker), (override)); - MOCK_METHOD(void, PopUtilWorker, + MOCK_METHOD(void, + PopUtilWorker, (std::function)> callback), (override)); }; @@ -65,13 +83,16 @@ namespace raylet { class MockWorkerPool : public WorkerPool { public: - MOCK_METHOD(Process, StartProcess, + MOCK_METHOD(Process, + StartProcess, (const std::vector &worker_command_args, const ProcessEnvironment &env), (override)); MOCK_METHOD(void, WarnAboutSize, (), (override)); - MOCK_METHOD(void, PopWorkerCallbackAsync, - (const PopWorkerCallback &callback, std::shared_ptr worker, + MOCK_METHOD(void, + PopWorkerCallbackAsync, + (const PopWorkerCallback &callback, + std::shared_ptr worker, PopWorkerStatus status), (override)); }; diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h index 528d50c73..1de4d76b8 100644 --- a/src/mock/ray/raylet_client/raylet_client.h +++ b/src/mock/ray/raylet_client/raylet_client.h @@ -16,7 +16,8 @@ namespace ray { class MockPinObjectsInterface : public PinObjectsInterface { public: - MOCK_METHOD(void, PinObjectIDs, + MOCK_METHOD(void, + PinObjectIDs, (const rpc::Address &caller_address, const std::vector &object_ids, const ray::rpc::ClientCallback &callback), @@ -30,20 +31,28 @@ namespace ray { class MockWorkerLeaseInterface : public WorkerLeaseInterface { public: MOCK_METHOD( - void, RequestWorkerLease, - (const rpc::TaskSpec &task_spec, bool grant_or_reject, + void, + RequestWorkerLease, + (const rpc::TaskSpec &task_spec, + bool grant_or_reject, const ray::rpc::ClientCallback &callback, - const int64_t backlog_size, const bool is_selected_based_on_locality), + const int64_t backlog_size, + const bool is_selected_based_on_locality), (override)); - MOCK_METHOD(ray::Status, ReturnWorker, - (int worker_port, const WorkerID &worker_id, bool disconnect_worker, + MOCK_METHOD(ray::Status, + ReturnWorker, + (int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, bool worker_exiting), (override)); - MOCK_METHOD(void, ReleaseUnusedWorkers, + MOCK_METHOD(void, + ReleaseUnusedWorkers, (const std::vector &workers_in_use, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, CancelWorkerLease, + MOCK_METHOD(void, + CancelWorkerLease, (const TaskID &task_id, const rpc::ClientCallback &callback), (override)); @@ -56,21 +65,25 @@ namespace ray { class MockResourceReserveInterface : public ResourceReserveInterface { public: MOCK_METHOD( - void, PrepareBundleResources, + void, + PrepareBundleResources, (const std::vector> &bundle_specs, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD( - void, CommitBundleResources, + void, + CommitBundleResources, (const std::vector> &bundle_specs, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD( - void, CancelResourceReserve, + void, + CancelResourceReserve, (const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, ReleaseUnusedBundles, + MOCK_METHOD(void, + ReleaseUnusedBundles, (const std::vector &bundles_in_use, const rpc::ClientCallback &callback), (override)); @@ -82,7 +95,8 @@ namespace ray { class MockDependencyWaiterInterface : public DependencyWaiterInterface { public: - MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs, + MOCK_METHOD(ray::Status, + WaitForDirectActorCallArgs, (const std::vector &references, int64_t tag), (override)); }; @@ -93,11 +107,13 @@ namespace ray { class MockResourceTrackingInterface : public ResourceTrackingInterface { public: - MOCK_METHOD(void, UpdateResourceUsage, + 
MOCK_METHOD(void, + UpdateResourceUsage, (std::string & serialized_resource_usage_batch, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, RequestResourceReport, + MOCK_METHOD(void, + RequestResourceReport, (const rpc::ClientCallback &callback), (override)); }; @@ -108,71 +124,92 @@ namespace ray { class MockRayletClientInterface : public RayletClientInterface { public: - MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs, + MOCK_METHOD(ray::Status, + WaitForDirectActorCallArgs, (const std::vector &references, int64_t tag), (override)); - MOCK_METHOD(void, ReportWorkerBacklog, + MOCK_METHOD(void, + ReportWorkerBacklog, (const WorkerID &worker_id, const std::vector &backlog_reports), (override)); MOCK_METHOD( - void, RequestWorkerLease, - (const rpc::TaskSpec &resource_spec, bool grant_or_reject, + void, + RequestWorkerLease, + (const rpc::TaskSpec &resource_spec, + bool grant_or_reject, const ray::rpc::ClientCallback &callback, - const int64_t backlog_size, const bool is_selected_based_on_locality), + const int64_t backlog_size, + const bool is_selected_based_on_locality), (override)); - MOCK_METHOD(ray::Status, ReturnWorker, - (int worker_port, const WorkerID &worker_id, bool disconnect_worker, + MOCK_METHOD(ray::Status, + ReturnWorker, + (int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, bool worker_exiting), (override)); - MOCK_METHOD(void, ReleaseUnusedWorkers, + MOCK_METHOD(void, + ReleaseUnusedWorkers, (const std::vector &workers_in_use, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, CancelWorkerLease, + MOCK_METHOD(void, + CancelWorkerLease, (const TaskID &task_id, const rpc::ClientCallback &callback), (override)); MOCK_METHOD( - void, PrepareBundleResources, + void, + PrepareBundleResources, (const std::vector> &bundle_specs, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD( - void, CommitBundleResources, + void, + CommitBundleResources, (const std::vector> &bundle_specs, const ray::rpc::ClientCallback &callback), (override)); MOCK_METHOD( - void, CancelResourceReserve, + void, + CancelResourceReserve, (const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, ReleaseUnusedBundles, + MOCK_METHOD(void, + ReleaseUnusedBundles, (const std::vector &bundles_in_use, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, PinObjectIDs, + MOCK_METHOD(void, + PinObjectIDs, (const rpc::Address &caller_address, const std::vector &object_ids, const ray::rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, GetSystemConfig, + MOCK_METHOD(void, + GetSystemConfig, (const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, GetGcsServerAddress, + MOCK_METHOD(void, + GetGcsServerAddress, (const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, UpdateResourceUsage, + MOCK_METHOD(void, + UpdateResourceUsage, (std::string & serialized_resource_usage_batch, const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, RequestResourceReport, + MOCK_METHOD(void, + RequestResourceReport, (const rpc::ClientCallback &callback), (override)); - MOCK_METHOD(void, ShutdownRaylet, - (const NodeID &node_id, bool graceful, + MOCK_METHOD(void, + ShutdownRaylet, + (const NodeID &node_id, + bool graceful, const rpc::ClientCallback &callback), (override)); }; diff --git a/src/mock/ray/rpc/worker/core_worker_client.h b/src/mock/ray/rpc/worker/core_worker_client.h index f4b7bcc79..7b67aac50 100644 --- 
a/src/mock/ray/rpc/worker/core_worker_client.h +++ b/src/mock/ray/rpc/worker/core_worker_client.h @@ -29,86 +29,108 @@ class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientIn public CoreWorkerClientInterface { public: MOCK_METHOD(const rpc::Address &, Addr, (), (const, override)); - MOCK_METHOD(void, PushActorTask, - (std::unique_ptr request, bool skip_queue, + MOCK_METHOD(void, + PushActorTask, + (std::unique_ptr request, + bool skip_queue, const ClientCallback &callback), (override)); - MOCK_METHOD(void, PushNormalTask, + MOCK_METHOD(void, + PushNormalTask, (std::unique_ptr request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, DirectActorCallArgWaitComplete, + MOCK_METHOD(void, + DirectActorCallArgWaitComplete, (const DirectActorCallArgWaitCompleteRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, GetObjectStatus, + MOCK_METHOD(void, + GetObjectStatus, (const GetObjectStatusRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, WaitForActorOutOfScope, + MOCK_METHOD(void, + WaitForActorOutOfScope, (const WaitForActorOutOfScopeRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, PubsubLongPolling, + MOCK_METHOD(void, + PubsubLongPolling, (const PubsubLongPollingRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, PubsubCommandBatch, + MOCK_METHOD(void, + PubsubCommandBatch, (const PubsubCommandBatchRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, UpdateObjectLocationBatch, + MOCK_METHOD(void, + UpdateObjectLocationBatch, (const UpdateObjectLocationBatchRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, GetObjectLocationsOwner, + MOCK_METHOD(void, + GetObjectLocationsOwner, (const GetObjectLocationsOwnerRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, KillActor, + MOCK_METHOD(void, + KillActor, (const KillActorRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, CancelTask, + MOCK_METHOD(void, + CancelTask, (const CancelTaskRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, RemoteCancelTask, + MOCK_METHOD(void, + RemoteCancelTask, (const RemoteCancelTaskRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, GetCoreWorkerStats, + MOCK_METHOD(void, + GetCoreWorkerStats, (const GetCoreWorkerStatsRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, LocalGC, + MOCK_METHOD(void, + LocalGC, (const LocalGCRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, SpillObjects, + MOCK_METHOD(void, + SpillObjects, (const SpillObjectsRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, RestoreSpilledObjects, + MOCK_METHOD(void, + RestoreSpilledObjects, (const RestoreSpilledObjectsRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, DeleteSpilledObjects, + MOCK_METHOD(void, + DeleteSpilledObjects, (const DeleteSpilledObjectsRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, AddSpilledUrl, + MOCK_METHOD(void, + AddSpilledUrl, (const AddSpilledUrlRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, PlasmaObjectReady, + MOCK_METHOD(void, + PlasmaObjectReady, (const PlasmaObjectReadyRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, 
Exit, + MOCK_METHOD(void, + Exit, (const ExitRequest &request, const ClientCallback &callback), (override)); - MOCK_METHOD(void, AssignObjectOwner, + MOCK_METHOD(void, + AssignObjectOwner, (const AssignObjectOwnerRequest &request, const ClientCallback &callback), (override)); diff --git a/src/ray/common/asio/asio_util.h b/src/ray/common/asio/asio_util.h index cea1b1d5a..0fa69d0f8 100644 --- a/src/ray/common/asio/asio_util.h +++ b/src/ray/common/asio/asio_util.h @@ -17,7 +17,8 @@ #include inline std::shared_ptr execute_after_us( - instrumented_io_context &io_context, std::function fn, + instrumented_io_context &io_context, + std::function fn, int64_t delay_microseconds) { auto timer = std::make_shared(io_context); timer->expires_from_now(boost::posix_time::microseconds(delay_microseconds)); diff --git a/src/ray/common/asio/io_service_pool.h b/src/ray/common/asio/io_service_pool.h index 6943aca21..b80599844 100644 --- a/src/ray/common/asio/io_service_pool.h +++ b/src/ray/common/asio/io_service_pool.h @@ -17,6 +17,7 @@ #include #include #include + #include "ray/common/asio/instrumented_io_context.h" namespace ray { diff --git a/src/ray/common/asio/periodical_runner.cc b/src/ray/common/asio/periodical_runner.cc index a6d771495..98b3218f3 100644 --- a/src/ray/common/asio/periodical_runner.cc +++ b/src/ray/common/asio/periodical_runner.cc @@ -31,7 +31,8 @@ PeriodicalRunner::~PeriodicalRunner() { RAY_LOG(DEBUG) << "PeriodicalRunner is destructed"; } -void PeriodicalRunner::RunFnPeriodically(std::function fn, uint64_t period_ms, +void PeriodicalRunner::RunFnPeriodically(std::function fn, + uint64_t period_ms, const std::string name) { if (period_ms > 0) { auto timer = std::make_shared(io_service_); @@ -53,13 +54,14 @@ void PeriodicalRunner::RunFnPeriodically(std::function fn, uint64_t peri } void PeriodicalRunner::DoRunFnPeriodically( - const std::function &fn, boost::posix_time::milliseconds period, + const std::function &fn, + boost::posix_time::milliseconds period, std::shared_ptr timer) { fn(); absl::MutexLock lock(&mutex_); timer->expires_from_now(period); - timer->async_wait([this, fn = std::move(fn), period, - timer = std::move(timer)](const boost::system::error_code &error) { + timer->async_wait([this, fn = std::move(fn), period, timer = std::move(timer)]( + const boost::system::error_code &error) { if (error == boost::asio::error::operation_aborted) { // `operation_aborted` is set when `timer` is canceled or destroyed. // The Monitor lifetime may be short than the object who use it. (e.g. @@ -72,8 +74,10 @@ void PeriodicalRunner::DoRunFnPeriodically( } void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( - const std::function &fn, boost::posix_time::milliseconds period, - std::shared_ptr timer, const std::string name) { + const std::function &fn, + boost::posix_time::milliseconds period, + std::shared_ptr timer, + const std::string name) { fn(); absl::MutexLock lock(&mutex_); timer->expires_from_now(period); @@ -81,7 +85,10 @@ void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( // which the handler was elgible to execute on the event loop but was queued by the // event loop. 
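For readers unfamiliar with the pattern these PeriodicalRunner hunks reflow, here is a minimal standalone sketch of the re-arming timer loop; the function and values below are illustrative, not part of this patch.

#include <boost/asio.hpp>
#include <functional>
#include <iostream>
#include <memory>

// Minimal sketch of the re-arming timer loop behind RunFnPeriodically: run the
// callback, push the timer's deadline out by one period, and schedule the next
// wait. operation_aborted means the timer was canceled or destroyed, so stop.
void RunPeriodically(std::shared_ptr<boost::asio::deadline_timer> timer,
                     std::function<void()> fn,
                     boost::posix_time::milliseconds period) {
  fn();
  timer->expires_from_now(period);
  timer->async_wait([timer, fn, period](const boost::system::error_code &error) {
    if (error == boost::asio::error::operation_aborted) {
      return;
    }
    RunPeriodically(timer, fn, period);
  });
}

int main() {
  boost::asio::io_context io;
  auto timer = std::make_shared<boost::asio::deadline_timer>(io);
  RunPeriodically(
      timer, [] { std::cout << "tick\n"; }, boost::posix_time::milliseconds(100));
  io.run();  // Runs until the timer is canceled; forever in this sketch.
}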
auto stats_handle = io_service_.stats().RecordStart(name, period.total_nanoseconds()); - timer->async_wait([this, fn = std::move(fn), period, timer = std::move(timer), + timer->async_wait([this, + fn = std::move(fn), + period, + timer = std::move(timer), stats_handle = std::move(stats_handle), name](const boost::system::error_code &error) { io_service_.stats().RecordExecution( diff --git a/src/ray/common/asio/periodical_runner.h b/src/ray/common/asio/periodical_runner.h index 6992158d2..94e9f4cbd 100644 --- a/src/ray/common/asio/periodical_runner.h +++ b/src/ray/common/asio/periodical_runner.h @@ -34,7 +34,8 @@ class PeriodicalRunner { ~PeriodicalRunner(); - void RunFnPeriodically(std::function fn, uint64_t period_ms, + void RunFnPeriodically(std::function fn, + uint64_t period_ms, const std::string name = "UNKNOWN") LOCKS_EXCLUDED(mutex_); private: diff --git a/src/ray/common/buffer.h b/src/ray/common/buffer.h index e75b1ebb9..0ce0829d0 100644 --- a/src/ray/common/buffer.h +++ b/src/ray/common/buffer.h @@ -143,7 +143,8 @@ class SharedMemoryBuffer : public Buffer { } static std::shared_ptr Slice(const std::shared_ptr &buffer, - int64_t offset, int64_t size) { + int64_t offset, + int64_t size) { return std::make_shared(buffer, offset, size); } diff --git a/src/ray/common/bundle_spec.cc b/src/ray/common/bundle_spec.cc index 65ab06c9b..7409ac06d 100644 --- a/src/ray/common/bundle_spec.cc +++ b/src/ray/common/bundle_spec.cc @@ -139,7 +139,8 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na original_resource_name, bundle_spec.PlacementGroupId(), bundle_spec.Index()); } -bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id, +bool IsBundleIndex(const std::string &resource, + const PlacementGroupID &group_id, const int bundle_index) { return resource.find(kGroupKeyword + std::to_string(bundle_index) + "_" + group_id.Hex()) != std::string::npos; diff --git a/src/ray/common/bundle_spec.h b/src/ray/common/bundle_spec.h index 0e3385ada..73122f351 100644 --- a/src/ray/common/bundle_spec.h +++ b/src/ray/common/bundle_spec.h @@ -102,7 +102,8 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na const BundleSpecification &bundle_spec); /// Return whether a formatted resource is a bundle of the given index. -bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id, +bool IsBundleIndex(const std::string &resource, + const PlacementGroupID &group_id, const int bundle_index); /// Return the original resource name of the placement group resource. diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 68ffaaf8f..b6f2308af 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -30,8 +30,10 @@ namespace ray { -Status ConnectSocketRetry(local_stream_socket &socket, const std::string &endpoint, - int num_retries, int64_t timeout_in_ms) { +Status ConnectSocketRetry(local_stream_socket &socket, + const std::string &endpoint, + int num_retries, + int64_t timeout_in_ms) { RAY_CHECK(num_retries != 0); // Pick the default values if the user did not specify. 
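ConnectSocketRetry, continued below, has the usual bounded-retry shape; this standalone sketch uses an assumed generic try_connect callback, and its default values (5 tries, 100 ms) are placeholders rather than the constants Ray picks.

#include <chrono>
#include <cstdint>
#include <functional>
#include <thread>

// Sketch of a bounded retry loop: negative arguments mean "use a default",
// and the loop sleeps for the timeout between attempts.
bool RetryWithTimeout(const std::function<bool()> &try_connect,
                      int num_retries,
                      int64_t timeout_in_ms) {
  if (num_retries < 0) num_retries = 5;
  if (timeout_in_ms < 0) timeout_in_ms = 100;
  for (int attempt = 0; attempt < num_retries; ++attempt) {
    if (try_connect()) {
      return true;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(timeout_in_ms));
  }
  return false;
}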
if (num_retries < 0) { @@ -114,7 +116,8 @@ void ServerConnection::WriteBufferAsync( const auto stats_handle = io_context.stats().RecordStart("ClientConnection.async_write.WriteBufferAsync"); boost::asio::async_write( - socket_, buffer, + socket_, + buffer, [handler, stats_handle = std::move(stats_handle)]( const boost::system::error_code &ec, size_t bytes_transferred) { EventTracker::RecordExecution( @@ -123,7 +126,8 @@ void ServerConnection::WriteBufferAsync( }); } else { boost::asio::async_write( - socket_, buffer, + socket_, + buffer, [handler](const boost::system::error_code &ec, size_t bytes_transferred) { handler(boost_to_ray_status(ec)); }); @@ -162,7 +166,8 @@ void ServerConnection::ReadBufferAsync( const auto stats_handle = io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync"); boost::asio::async_read( - socket_, buffer, + socket_, + buffer, [handler, stats_handle = std::move(stats_handle)]( const boost::system::error_code &ec, size_t bytes_transferred) { EventTracker::RecordExecution( @@ -171,14 +176,16 @@ void ServerConnection::ReadBufferAsync( }); } else { boost::asio::async_read( - socket_, buffer, + socket_, + buffer, [handler](const boost::system::error_code &ec, size_t bytes_transferred) { handler(boost_to_ray_status(ec)); }); } } -ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, +ray::Status ServerConnection::WriteMessage(int64_t type, + int64_t length, const uint8_t *message) { sync_writes_ += 1; bytes_written_ += length; @@ -218,7 +225,9 @@ Status ServerConnection::ReadMessage(int64_t type, std::vector *message } void ServerConnection::WriteMessageAsync( - int64_t type, int64_t length, const uint8_t *message, + int64_t type, + int64_t length, + const uint8_t *message, const std::function &handler) { async_writes_ += 1; bytes_written_ += length; @@ -294,8 +303,12 @@ void ServerConnection::DoAsyncWrites() { const auto stats_handle = io_context.stats().RecordStart("ClientConnection.async_write.DoAsyncWrites"); boost::asio::async_write( - socket_, message_buffers, - [this, this_ptr, num_messages, call_handlers, + socket_, + message_buffers, + [this, + this_ptr, + num_messages, + call_handlers, stats_handle = std::move(stats_handle)](const boost::system::error_code &error, size_t bytes_transferred) { EventTracker::RecordExecution( @@ -319,7 +332,8 @@ void ServerConnection::DoAsyncWrites() { }); } else { boost::asio::async_write( - ServerConnection::socket_, message_buffers, + ServerConnection::socket_, + message_buffers, [this, this_ptr, num_messages, call_handlers]( const boost::system::error_code &error, size_t bytes_transferred) { ray::Status status = boost_to_ray_status(error); @@ -341,22 +355,30 @@ void ServerConnection::DoAsyncWrites() { } std::shared_ptr ClientConnection::Create( - ClientHandler &client_handler, MessageHandler &message_handler, - local_stream_socket &&socket, const std::string &debug_label, - const std::vector &message_type_enum_names, int64_t error_message_type, + ClientHandler &client_handler, + MessageHandler &message_handler, + local_stream_socket &&socket, + const std::string &debug_label, + const std::vector &message_type_enum_names, + int64_t error_message_type, const std::vector &error_message_data) { - std::shared_ptr self(new ClientConnection( - message_handler, std::move(socket), debug_label, message_type_enum_names, - error_message_type, error_message_data)); + std::shared_ptr self(new ClientConnection(message_handler, + std::move(socket), + debug_label, + message_type_enum_names, + 
error_message_type, + error_message_data)); // Let our manager process our new connection. client_handler(*self); return self; } ClientConnection::ClientConnection( - MessageHandler &message_handler, local_stream_socket &&socket, + MessageHandler &message_handler, + local_stream_socket &&socket, const std::string &debug_label, - const std::vector &message_type_enum_names, int64_t error_message_type, + const std::vector &message_type_enum_names, + int64_t error_message_type, const std::vector &error_message_data) : ServerConnection(std::move(socket)), registered_(false), @@ -386,7 +408,8 @@ void ClientConnection::ProcessMessages() { const auto stats_handle = io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync"); boost::asio::async_read( - ServerConnection::socket_, header, + ServerConnection::socket_, + header, [this, this_ptr, stats_handle = std::move(stats_handle)]( const boost::system::error_code &ec, size_t bytes_transferred) { EventTracker::RecordExecution( @@ -394,7 +417,8 @@ void ClientConnection::ProcessMessages() { std::move(stats_handle)); }); } else { - boost::asio::async_read(ServerConnection::socket_, header, + boost::asio::async_read(ServerConnection::socket_, + header, boost::bind(&ClientConnection::ProcessMessageHeader, shared_ClientConnection_from_this(), boost::asio::placeholders::error)); @@ -428,14 +452,16 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code &err const auto stats_handle = io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync"); boost::asio::async_read( - ServerConnection::socket_, boost::asio::buffer(read_message_), + ServerConnection::socket_, + boost::asio::buffer(read_message_), [this, this_ptr, stats_handle = std::move(stats_handle)]( const boost::system::error_code &ec, size_t bytes_transferred) { EventTracker::RecordExecution([this, this_ptr, ec]() { ProcessMessage(ec); }, std::move(stats_handle)); }); } else { - boost::asio::async_read(ServerConnection::socket_, boost::asio::buffer(read_message_), + boost::asio::async_read(ServerConnection::socket_, + boost::asio::buffer(read_message_), boost::bind(&ClientConnection::ProcessMessage, shared_ClientConnection_from_this(), boost::asio::placeholders::error)); diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 129ca69f5..6f3613782 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -31,8 +31,10 @@ typedef boost::asio::generic::stream_protocol local_stream_protocol; typedef boost::asio::basic_stream_socket local_stream_socket; /// Connect to a socket with retry times. -Status ConnectSocketRetry(local_stream_socket &socket, const std::string &endpoint, - int num_retries = -1, int64_t timeout_in_ms = -1); +Status ConnectSocketRetry(local_stream_socket &socket, + const std::string &endpoint, + int num_retries = -1, + int64_t timeout_in_ms = -1); /// \typename ServerConnection /// @@ -63,7 +65,9 @@ class ServerConnection : public std::enable_shared_from_this { /// \param length The size in bytes of the message. /// \param message A pointer to the message buffer. /// \param handler A callback to run on write completion. - void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message, + void WriteMessageAsync(int64_t type, + int64_t length, + const uint8_t *message, const std::function &handler); /// Read a message from the client. 
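The WriteMessage/ReadMessage pair in this header speaks a simple binary framing: a message type, a length, then the payload bytes. The sketch below shows the general idea only; its on-wire layout is an assumption for illustration, and Ray's real protocol carries additional fields.

#include <cstdint>
#include <cstring>
#include <vector>

// Generic sketch of type/length/payload framing in the spirit of
// WriteMessage(type, length, message): an int64 type, an int64 length,
// then the raw payload.
std::vector<uint8_t> FrameMessage(int64_t type, const std::vector<uint8_t> &payload) {
  const int64_t length = static_cast<int64_t>(payload.size());
  std::vector<uint8_t> frame(2 * sizeof(int64_t) + payload.size());
  std::memcpy(frame.data(), &type, sizeof(type));
  std::memcpy(frame.data() + sizeof(type), &length, sizeof(length));
  if (!payload.empty()) {
    std::memcpy(frame.data() + 2 * sizeof(int64_t), payload.data(), payload.size());
  }
  return frame;
}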
@@ -169,8 +173,8 @@ class ServerConnection : public std::enable_shared_from_this { class ClientConnection; using ClientHandler = std::function; -using MessageHandler = std::function, int64_t, - const std::vector &)>; +using MessageHandler = std::function, int64_t, const std::vector &)>; static std::vector _dummy_error_message_data; /// \typename ClientConnection @@ -195,9 +199,12 @@ class ClientConnection : public ServerConnection { /// \param error_message_data the companion data to the error message type. /// \return std::shared_ptr. static std::shared_ptr Create( - ClientHandler &new_client_handler, MessageHandler &message_handler, - local_stream_socket &&socket, const std::string &debug_label, - const std::vector &message_type_enum_names, int64_t error_message_type, + ClientHandler &new_client_handler, + MessageHandler &message_handler, + local_stream_socket &&socket, + const std::string &debug_label, + const std::vector &message_type_enum_names, + int64_t error_message_type, const std::vector &error_message_data = _dummy_error_message_data); std::shared_ptr shared_ClientConnection_from_this() { @@ -215,9 +222,11 @@ class ClientConnection : public ServerConnection { protected: /// A protected constructor for a node client connection. ClientConnection( - MessageHandler &message_handler, local_stream_socket &&socket, + MessageHandler &message_handler, + local_stream_socket &&socket, const std::string &debug_label, - const std::vector &message_type_enum_names, int64_t error_message_type, + const std::vector &message_type_enum_names, + int64_t error_message_type, const std::vector &error_message_data = _dummy_error_message_data); /// Process an error from the last operation, then process the message /// header from the client. diff --git a/src/ray/common/event_stats.cc b/src/ray/common/event_stats.cc index b2e38fb77..c4f1badba 100644 --- a/src/ray/common/event_stats.cc +++ b/src/ray/common/event_stats.cc @@ -71,7 +71,9 @@ std::shared_ptr EventTracker::RecordStart( ray::stats::STATS_operation_count.Record(curr_count, name); ray::stats::STATS_operation_active_count.Record(curr_count, name); return std::make_shared( - name, absl::GetCurrentTimeNanos() + expected_queueing_delay_ns, std::move(stats), + name, + absl::GetCurrentTimeNanos() + expected_queueing_delay_ns, + std::move(stats), global_stats_); } @@ -165,7 +167,8 @@ std::vector> EventTracker::get_event_stats() absl::ReaderMutexLock lock(&mutex_); std::vector> stats; stats.reserve(post_handler_stats_.size()); - std::transform(post_handler_stats_.begin(), post_handler_stats_.end(), + std::transform(post_handler_stats_.begin(), + post_handler_stats_.end(), std::back_inserter(stats), [](const std::pair> &p) { return std::make_pair(p.first, to_event_stats_view(p.second)); @@ -181,7 +184,8 @@ std::string EventTracker::StatsString() const { } auto stats = get_event_stats(); // Sort stats by cumulative count, outside of the table lock. 
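A self-contained version of the copy-then-sort idiom used by get_event_stats() and StatsString() above, with simple int counts standing in for the real stats types:

#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Copy map entries into a vector with std::transform/back_inserter, then sort
// descending by count outside any lock, mirroring the shape of StatsString().
std::vector<std::pair<std::string, int>> SortedByCount(
    const std::map<std::string, int> &counts) {
  std::vector<std::pair<std::string, int>> stats;
  stats.reserve(counts.size());
  std::transform(counts.begin(), counts.end(), std::back_inserter(stats),
                 [](const auto &p) { return std::make_pair(p.first, p.second); });
  std::sort(stats.begin(), stats.end(),
            [](const auto &a, const auto &b) { return a.second > b.second; });
  return stats;
}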
- sort(stats.begin(), stats.end(), + sort(stats.begin(), + stats.end(), [](const std::pair &a, const std::pair &b) { return a.second.cum_count > b.second.cum_count; diff --git a/src/ray/common/event_stats.h b/src/ray/common/event_stats.h index 4591f79b0..1af2321e3 100644 --- a/src/ray/common/event_stats.h +++ b/src/ray/common/event_stats.h @@ -71,7 +71,8 @@ struct StatsHandle { std::shared_ptr global_stats; std::atomic execution_recorded; - StatsHandle(std::string event_name_, int64_t start_time_, + StatsHandle(std::string event_name_, + int64_t start_time_, std::shared_ptr handler_stats_, std::shared_ptr global_stats_) : event_name(std::move(event_name_)), diff --git a/src/ray/common/function_descriptor.cc b/src/ray/common/function_descriptor.cc index f3a8bde99..332660886 100644 --- a/src/ray/common/function_descriptor.cc +++ b/src/ray/common/function_descriptor.cc @@ -33,8 +33,10 @@ FunctionDescriptor FunctionDescriptorBuilder::BuildJava(const std::string &class } FunctionDescriptor FunctionDescriptorBuilder::BuildPython( - const std::string &module_name, const std::string &class_name, - const std::string &function_name, const std::string &function_hash) { + const std::string &module_name, + const std::string &class_name, + const std::string &function_name, + const std::string &function_hash) { rpc::FunctionDescriptor descriptor; auto typed_descriptor = descriptor.mutable_python_function_descriptor(); typed_descriptor->set_module_name(module_name); diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index 175af57f5..52a898def 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -82,14 +82,16 @@ inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { } else { // Unlike `UNKNOWN`, `ABORTED` is never generated by the library, so using it means // more robust. - return grpc::Status(grpc::StatusCode::ABORTED, ray_status.CodeAsString(), - ray_status.message()); + return grpc::Status( + grpc::StatusCode::ABORTED, ray_status.CodeAsString(), ray_status.message()); } } inline std::string GrpcStatusToRayStatusMessage(const grpc::Status &grpc_status) { - return absl::StrCat("RPC Error message: ", grpc_status.error_message(), - "; RPC Error details: ", grpc_status.error_details()); + return absl::StrCat("RPC Error message: ", + grpc_status.error_message(), + "; RPC Error details: ", + grpc_status.error_details()); } /// Helper function that converts a gRPC status to ray status. @@ -135,8 +137,8 @@ inline std::vector IdVectorFromProtobuf( const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) { auto str_vec = VectorFromProtobuf(pb_repeated); std::vector ret; - std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(ret), - &ID::FromBinary); + std::transform( + str_vec.begin(), str_vec.end(), std::back_inserter(ret), &ID::FromBinary); return ret; } diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 16de6db17..897bced08 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -40,14 +40,17 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); /// A helper function to generate the unique bytes by hash. 
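GenerateUniqueBytes, reflowed just below, derives fresh ID bytes by hashing the parent identifiers and counter. This simplified stand-in uses std::hash plus a seeded PRNG purely for illustration; the real helper feeds the same inputs to SHA-256 and truncates the digest.

#include <cstddef>
#include <functional>
#include <random>
#include <string>

// Simplified stand-in for hash-derived ID bytes: fold the inputs into a seed,
// then expand the seed into `length` pseudo-random bytes.
std::string DeriveIdBytes(const std::string &job_id,
                          const std::string &parent_task_id,
                          size_t parent_task_counter,
                          size_t length) {
  size_t seed = std::hash<std::string>{}(job_id);
  seed ^= std::hash<std::string>{}(parent_task_id) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  seed ^= parent_task_counter + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  std::mt19937_64 rng(seed);
  std::string bytes;
  bytes.reserve(length);
  for (size_t i = 0; i < length; ++i) {
    bytes.push_back(static_cast<char>(rng() & 0xff));
  }
  return bytes;
}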
__suppress_ubsan__("undefined") std::string - GenerateUniqueBytes(const JobID &job_id, const TaskID &parent_task_id, - size_t parent_task_counter, size_t extra_bytes, size_t length) { + GenerateUniqueBytes(const JobID &job_id, + const TaskID &parent_task_id, + size_t parent_task_counter, + size_t extra_bytes, + size_t length) { RAY_CHECK(length <= DIGEST_SIZE); SHA256_CTX ctx; sha256_init(&ctx); sha256_update(&ctx, reinterpret_cast(job_id.Data()), job_id.Size()); - sha256_update(&ctx, reinterpret_cast(parent_task_id.Data()), - parent_task_id.Size()); + sha256_update( + &ctx, reinterpret_cast(parent_task_id.Data()), parent_task_id.Size()); sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter)); if (extra_bytes > 0) { sha256_update(&ctx, (const BYTE *)&extra_bytes, sizeof(extra_bytes)); @@ -124,14 +127,17 @@ __suppress_ubsan__("undefined") uint64_t return h; } -ActorID ActorID::Of(const JobID &job_id, const TaskID &parent_task_id, +ActorID ActorID::Of(const JobID &job_id, + const TaskID &parent_task_id, const size_t parent_task_counter) { // NOTE(swang): Include the current time in the hash for the actor ID so that // we avoid duplicating a previous actor ID, which is not allowed by the GCS. // See https://github.com/ray-project/ray/issues/10481. - auto data = - GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, - absl::GetCurrentTimeNanos(), ActorID::kUniqueBytesLength); + auto data = GenerateUniqueBytes(job_id, + parent_task_id, + parent_task_counter, + absl::GetCurrentTimeNanos(), + ActorID::kUniqueBytesLength); std::copy_n(job_id.Data(), JobID::kLength, std::back_inserter(data)); RAY_CHECK(data.size() == kLength); return ActorID::FromBinary(data); @@ -175,19 +181,22 @@ TaskID TaskID::ForActorCreationTask(const ActorID &actor_id) { return TaskID::FromBinary(data); } -TaskID TaskID::ForActorTask(const JobID &job_id, const TaskID &parent_task_id, - size_t parent_task_counter, const ActorID &actor_id) { - std::string data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, 0, - TaskID::kUniqueBytesLength); +TaskID TaskID::ForActorTask(const JobID &job_id, + const TaskID &parent_task_id, + size_t parent_task_counter, + const ActorID &actor_id) { + std::string data = GenerateUniqueBytes( + job_id, parent_task_id, parent_task_counter, 0, TaskID::kUniqueBytesLength); std::copy_n(actor_id.Data(), ActorID::kLength, std::back_inserter(data)); RAY_CHECK(data.size() == TaskID::kLength); return TaskID::FromBinary(data); } -TaskID TaskID::ForNormalTask(const JobID &job_id, const TaskID &parent_task_id, +TaskID TaskID::ForNormalTask(const JobID &job_id, + const TaskID &parent_task_id, size_t parent_task_counter) { - std::string data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, 0, - TaskID::kUniqueBytesLength); + std::string data = GenerateUniqueBytes( + job_id, parent_task_id, parent_task_counter, 0, TaskID::kUniqueBytesLength); const auto dummy_actor_id = ActorID::NilFromJob(job_id); std::copy_n(dummy_actor_id.Data(), ActorID::kLength, std::back_inserter(data)); RAY_CHECK(data.size() == TaskID::kLength); diff --git a/src/ray/common/id.h b/src/ray/common/id.h index c20cffab8..5a43885fd 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -144,7 +144,8 @@ class ActorID : public BaseID { /// \param parent_task_counter The counter of the parent task. /// /// \return The random `ActorID`. 
- static ActorID Of(const JobID &job_id, const TaskID &parent_task_id, + static ActorID Of(const JobID &job_id, + const TaskID &parent_task_id, const size_t parent_task_counter); /// Creates a nil ActorID with the given job. @@ -210,8 +211,10 @@ class TaskID : public BaseID { /// \param actor_id The ID of the actor to which this task belongs. /// /// \return The ID of the actor task. - static TaskID ForActorTask(const JobID &job_id, const TaskID &parent_task_id, - size_t parent_task_counter, const ActorID &actor_id); + static TaskID ForActorTask(const JobID &job_id, + const TaskID &parent_task_id, + size_t parent_task_counter, + const ActorID &actor_id); /// Creates a TaskID for normal task. /// @@ -221,7 +224,8 @@ class TaskID : public BaseID { /// parent task before this one. /// /// \return The ID of the normal task. - static TaskID ForNormalTask(const JobID &job_id, const TaskID &parent_task_id, + static TaskID ForNormalTask(const JobID &job_id, + const TaskID &parent_task_id, size_t parent_task_counter); /// Given a base task ID, create a task ID that represents the n-th execution diff --git a/src/ray/common/network_util.cc b/src/ray/common/network_util.cc index d0cebcc20..6a2ee33ad 100644 --- a/src/ray/common/network_util.cc +++ b/src/ray/common/network_util.cc @@ -23,8 +23,11 @@ std::string GetValidLocalIp(int port, int64_t timeout_ms) { const std::string localhost_ip = "127.0.0.1"; bool is_timeout; - if (async_client.Connect(kPublicDNSServerIp, kPublicDNSServerPort, timeout_ms, - &is_timeout, &error_code)) { + if (async_client.Connect(kPublicDNSServerIp, + kPublicDNSServerPort, + timeout_ms, + &is_timeout, + &error_code)) { address = async_client.GetLocalIPAddress(); } else { address = localhost_ip; @@ -41,7 +44,9 @@ std::string GetValidLocalIp(int port, int64_t timeout_ms) { primary_endpoint.address(ip_candidate); AsyncClient client; - if (client.Connect(primary_endpoint.address().to_string(), port, timeout_ms, + if (client.Connect(primary_endpoint.address().to_string(), + port, + timeout_ms, &is_timeout)) { success = true; break; @@ -110,8 +115,8 @@ std::vector GetValidLocalIpCandidates() { freeifaddrs(ifs_info); // Bigger prefixes must be tested first in CompNameAndIps - std::sort(prefixes_and_priorities.begin(), prefixes_and_priorities.end(), - CompPrefixLen); + std::sort( + prefixes_and_priorities.begin(), prefixes_and_priorities.end(), CompPrefixLen); // Filter out interfaces with small possibility of being desired to be used to serve std::sort(ifnames_and_ips.begin(), ifnames_and_ips.end(), CompNamesAndIps); @@ -140,7 +145,8 @@ std::vector GetValidLocalIpCandidates() { instrumented_io_context io_context; boost::asio::ip::tcp::resolver resolver(io_context); boost::asio::ip::tcp::resolver::query query( - boost::asio::ip::host_name(), "", + boost::asio::ip::host_name(), + "", boost::asio::ip::resolver_query_base::flags::v4_mapped); boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query); diff --git a/src/ray/common/network_util.h b/src/ray/common/network_util.h index 8f268ec46..c9aa8937a 100644 --- a/src/ray/common/network_util.h +++ b/src/ray/common/network_util.h @@ -57,7 +57,10 @@ class AsyncClient { /// \param is_timeout Whether connection timeout. /// \param error_code Set to indicate what error occurred, if any. /// \return Whether the connection is successful. 
- bool Connect(const std::string &ip, int port, int64_t timeout_ms, bool *is_timeout, + bool Connect(const std::string &ip, + int port, + int64_t timeout_ms, + bool *is_timeout, boost::system::error_code *error_code = nullptr) { try { auto endpoint = @@ -65,9 +68,11 @@ class AsyncClient { bool is_connected = false; *is_timeout = false; - socket_.async_connect(endpoint, boost::bind(&AsyncClient::ConnectHandle, this, - boost::asio::placeholders::error, - boost::ref(is_connected))); + socket_.async_connect(endpoint, + boost::bind(&AsyncClient::ConnectHandle, + this, + boost::asio::placeholders::error, + boost::ref(is_connected))); // Set a deadline for the asynchronous operation. timer_.expires_from_now(boost::posix_time::milliseconds(timeout_ms)); diff --git a/src/ray/common/placement_group.h b/src/ray/common/placement_group.h index 532f69d74..16c718b65 100644 --- a/src/ray/common/placement_group.h +++ b/src/ray/common/placement_group.h @@ -65,10 +65,13 @@ class PlacementGroupSpecBuilder { /// /// \return Reference to the builder object itself. PlacementGroupSpecBuilder &SetPlacementGroupSpec( - const PlacementGroupID &placement_group_id, std::string name, + const PlacementGroupID &placement_group_id, + std::string name, const std::vector<std::unordered_map<std::string, double>> &bundles, - const rpc::PlacementStrategy strategy, const bool is_detached, - const JobID &creator_job_id, const ActorID &creator_actor_id, + const rpc::PlacementStrategy strategy, + const bool is_detached, + const JobID &creator_job_id, + const ActorID &creator_actor_id, bool is_creator_detached_actor) { message_->set_placement_group_id(placement_group_id.Binary()); message_->set_name(name); diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 21690198d..092e92c88 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -245,7 +245,8 @@ RAY_CONFIG(uint64_t, object_manager_default_chunk_size, 5 * 1024 * 1024) /// The maximum number of outbound bytes to allow to be outstanding. This avoids /// excessive memory usage during object broadcast to many receivers. -RAY_CONFIG(uint64_t, object_manager_max_bytes_in_flight, +RAY_CONFIG(uint64_t, + object_manager_max_bytes_in_flight, ((uint64_t)2) * 1024 * 1024 * 1024) /// Maximum number of ids in one batch to send to GCS to delete keys. @@ -478,7 +479,8 @@ RAY_CONFIG(bool, enable_light_weight_resource_report, true) // The number of seconds to wait for the Raylet to start. This is normally // fast, but when RAY_preallocate_plasma_memory=1 is set, it may take some time // (a few GB/s) to populate all the pages on Raylet startup. -RAY_CONFIG(uint32_t, raylet_start_wait_time_s, +RAY_CONFIG(uint32_t, + raylet_start_wait_time_s, std::getenv("RAY_preallocate_plasma_memory") != nullptr && std::getenv("RAY_preallocate_plasma_memory") == std::string("1") ?
120 diff --git a/src/ray/common/ray_object.cc b/src/ray/common/ray_object.cc index 0bfedd81f..2a1043436 100644 --- a/src/ray/common/ray_object.cc +++ b/src/ray/common/ray_object.cc @@ -74,7 +74,8 @@ std::shared_ptr MakeSerializedErrorBuffer( kMessagePackOffset); // copy msgpack-serialized bytes std::memcpy(final_buffer->Data() + kMessagePackOffset, - msgpack_serialized_exception.data(), msgpack_serialized_exception.size()); + msgpack_serialized_exception.data(), + msgpack_serialized_exception.size()); // copy offset msgpack::sbuffer msgpack_int; msgpack::pack(msgpack_int, msgpack_serialized_exception.size()); diff --git a/src/ray/common/ray_object.h b/src/ray/common/ray_object.h index d91d3fccb..6ae9d64de 100644 --- a/src/ray/common/ray_object.h +++ b/src/ray/common/ray_object.h @@ -39,7 +39,8 @@ class RayObject { /// \param[in] metadata Metadata of the ray object. /// \param[in] nested_rfs ObjectRefs that were serialized in data. /// \param[in] copy_data Whether this class should hold a copy of data. - RayObject(const std::shared_ptr &data, const std::shared_ptr &metadata, + RayObject(const std::shared_ptr &data, + const std::shared_ptr &metadata, const std::vector &nested_refs, bool copy_data = false) { Init(data, metadata, nested_refs, copy_data); @@ -125,7 +126,8 @@ class RayObject { int64_t CreationTimeNanos() const { return creation_time_nanos_; } private: - void Init(const std::shared_ptr &data, const std::shared_ptr &metadata, + void Init(const std::shared_ptr &data, + const std::shared_ptr &metadata, const std::vector &nested_refs, bool copy_data = false) { data_ = data; @@ -138,7 +140,8 @@ class RayObject { // If this object is required to hold a copy of the data, // make a copy if the passed in buffers don't already have a copy. if (data_ && !data_->OwnsData()) { - data_ = std::make_shared(data_->Data(), data_->Size(), + data_ = std::make_shared(data_->Data(), + data_->Size(), /*copy_data=*/true); } diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index f643ec018..57e990e8c 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -176,7 +176,8 @@ int TaskSpecification::GetRuntimeEnvHash() const { required_resource = GetRequiredResources().GetResourceMap(); } WorkerCacheKey env = { - SerializedRuntimeEnv(), required_resource, + SerializedRuntimeEnv(), + required_resource, IsActorCreationTask() && RayConfig::instance().isolate_workers_across_task_types(), GetRequiredResources().GetResource("GPU") > 0 && RayConfig::instance().isolate_workers_across_resource_types()}; @@ -451,7 +452,8 @@ std::string TaskSpecification::CallSiteString() const { WorkerCacheKey::WorkerCacheKey( const std::string serialized_runtime_env, - const absl::flat_hash_map &required_resources, bool is_actor, + const absl::flat_hash_map &required_resources, + bool is_actor, bool is_gpu) : serialized_runtime_env(serialized_runtime_env), required_resources(std::move(required_resources)), diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index fbf01bf6f..0f0efb596 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -87,7 +87,8 @@ struct ConcurrencyGroup { ConcurrencyGroup() = default; - ConcurrencyGroup(const std::string &name, uint32_t max_concurrency, + ConcurrencyGroup(const std::string &name, + uint32_t max_concurrency, const std::vector &fds) : name(name), max_concurrency(max_concurrency), function_descriptors(fds) {} @@ -347,7 +348,8 @@ class WorkerCacheKey { /// resource type isolation 
between workers is enabled. WorkerCacheKey(const std::string serialized_runtime_env, const absl::flat_hash_map &required_resources, - bool is_actor, bool is_gpu); + bool is_actor, + bool is_gpu); bool operator==(const WorkerCacheKey &k) const; diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c043b4233..4f5305f9b 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -34,7 +34,8 @@ class TaskArgByReference : public TaskArg { /// /// \param[in] object_id Id of the argument. /// \return The task argument. - TaskArgByReference(const ObjectID &object_id, const rpc::Address &owner_address, + TaskArgByReference(const ObjectID &object_id, + const rpc::Address &owner_address, const std::string &call_site) : id_(object_id), owner_address_(owner_address), call_site_(call_site) {} @@ -97,13 +98,20 @@ class TaskSpecBuilder { /// /// \return Reference to the builder object itself. TaskSpecBuilder &SetCommonTaskSpec( - const TaskID &task_id, const std::string name, const Language &language, - const ray::FunctionDescriptor &function_descriptor, const JobID &job_id, - const TaskID &parent_task_id, uint64_t parent_counter, const TaskID &caller_id, - const rpc::Address &caller_address, uint64_t num_returns, + const TaskID &task_id, + const std::string name, + const Language &language, + const ray::FunctionDescriptor &function_descriptor, + const JobID &job_id, + const TaskID &parent_task_id, + uint64_t parent_counter, + const TaskID &caller_id, + const rpc::Address &caller_address, + uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const std::string &debugger_breakpoint, int64_t depth, + const std::string &debugger_breakpoint, + int64_t depth, const std::shared_ptr runtime_env_info = nullptr, const std::string &concurrency_group_name = "") { message_->set_type(TaskType::NORMAL_TASK); @@ -130,7 +138,8 @@ class TaskSpecBuilder { return *this; } - TaskSpecBuilder &SetNormalTaskSpec(int max_retries, bool retry_exceptions, + TaskSpecBuilder &SetNormalTaskSpec(int max_retries, + bool retry_exceptions, const rpc::SchedulingStrategy &scheduling_strategy) { message_->set_max_retries(max_retries); message_->set_retry_exceptions(retry_exceptions); @@ -142,8 +151,10 @@ class TaskSpecBuilder { /// See `common.proto` for meaning of the arguments. /// /// \return Reference to the builder object itself. - TaskSpecBuilder &SetDriverTaskSpec(const TaskID &task_id, const Language &language, - const JobID &job_id, const TaskID &parent_task_id, + TaskSpecBuilder &SetDriverTaskSpec(const TaskID &task_id, + const Language &language, + const JobID &job_id, + const TaskID &parent_task_id, const TaskID &caller_id, const rpc::Address &caller_address) { message_->set_type(TaskType::DRIVER_TASK); @@ -170,14 +181,20 @@ class TaskSpecBuilder { /// /// \return Reference to the builder object itself. 
TaskSpecBuilder &SetActorCreationTaskSpec( - const ActorID &actor_id, const std::string &serialized_actor_handle, - const rpc::SchedulingStrategy &scheduling_strategy, int64_t max_restarts = 0, + const ActorID &actor_id, + const std::string &serialized_actor_handle, + const rpc::SchedulingStrategy &scheduling_strategy, + int64_t max_restarts = 0, int64_t max_task_retries = 0, const std::vector &dynamic_worker_options = {}, - int max_concurrency = 1, bool is_detached = false, std::string name = "", - std::string ray_namespace = "", bool is_asyncio = false, + int max_concurrency = 1, + bool is_detached = false, + std::string name = "", + std::string ray_namespace = "", + bool is_asyncio = false, const std::vector &concurrency_groups = {}, - const std::string &extension_data = "", bool execute_out_of_order = false) { + const std::string &extension_data = "", + bool execute_out_of_order = false) { message_->set_type(TaskType::ACTOR_CREATION_TASK); auto actor_creation_spec = message_->mutable_actor_creation_task_spec(); actor_creation_spec->set_actor_id(actor_id.Binary()); diff --git a/src/ray/common/test/client_connection_test.cc b/src/ray/common/test/client_connection_test.cc index 8db2ed8c9..79808dd38 100644 --- a/src/ray/common/test/client_connection_test.cc +++ b/src/ray/common/test/client_connection_test.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "ray/common/client_connection.h" -#include "ray/common/asio/instrumented_io_context.h" #include #include @@ -22,6 +21,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/asio/instrumented_io_context.h" namespace ray { namespace raylet { @@ -45,8 +45,10 @@ class ClientConnectionTest : public ::testing::Test { #endif } - ray::Status WriteBadMessage(std::shared_ptr conn, int64_t type, - int64_t length, const uint8_t *message) { + ray::Status WriteBadMessage(std::shared_ptr conn, + int64_t type, + int64_t length, + const uint8_t *message) { std::vector message_buffers; auto write_cookie = 123456; // incorrect version. 
message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie))); @@ -69,18 +71,19 @@ TEST_F(ClientConnectionTest, SimpleSyncWrite) { ClientHandler client_handler = [](ClientConnection &client) {}; - MessageHandler message_handler = - [&arr, &num_messages](std::shared_ptr client, - int64_t message_type, const std::vector &message) { - ASSERT_TRUE(!std::memcmp(arr, message.data(), 5)); - num_messages += 1; - }; + MessageHandler message_handler = [&arr, &num_messages]( + std::shared_ptr client, + int64_t message_type, + const std::vector &message) { + ASSERT_TRUE(!std::memcmp(arr, message.data(), 5)); + num_messages += 1; + }; - auto conn1 = ClientConnection::Create(client_handler, message_handler, std::move(in_), - "conn1", {}, error_message_type_); + auto conn1 = ClientConnection::Create( + client_handler, message_handler, std::move(in_), "conn1", {}, error_message_type_); - auto conn2 = ClientConnection::Create(client_handler, message_handler, std::move(out_), - "conn2", {}, error_message_type_); + auto conn2 = ClientConnection::Create( + client_handler, message_handler, std::move(out_), "conn2", {}, error_message_type_); RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr)); RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr)); @@ -121,11 +124,15 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) { } }; - auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), - "writer", {}, error_message_type_); + auto writer = ClientConnection::Create( + client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); - reader = ClientConnection::Create(client_handler, message_handler, std::move(out_), - "reader", {}, error_message_type_); + reader = ClientConnection::Create(client_handler, + message_handler, + std::move(out_), + "reader", + {}, + error_message_type_); std::function callback = [](const ray::Status &status) { RAY_CHECK_OK(status); @@ -178,8 +185,8 @@ TEST_F(ClientConnectionTest, SimpleAsyncError) { int64_t message_type, const std::vector &message) {}; - auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), - "writer", {}, error_message_type_); + auto writer = ClientConnection::Create( + client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); std::function callback = [](const ray::Status &status) { ASSERT_TRUE(!status.ok()); @@ -199,8 +206,8 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { int64_t message_type, const std::vector &message) {}; - auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), - "writer", {}, error_message_type_); + auto writer = ClientConnection::Create( + client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); std::function callback = [writer](const ray::Status &status) { @@ -217,18 +224,23 @@ TEST_F(ClientConnectionTest, ProcessBadMessage) { ClientHandler client_handler = [](ClientConnection &client) {}; - MessageHandler message_handler = - [&arr, &num_messages](std::shared_ptr client, - int64_t message_type, const std::vector &message) { - ASSERT_TRUE(!std::memcmp(arr, message.data(), 5)); - num_messages += 1; - }; + MessageHandler message_handler = [&arr, &num_messages]( + std::shared_ptr client, + int64_t message_type, + const std::vector &message) { + ASSERT_TRUE(!std::memcmp(arr, message.data(), 5)); + num_messages += 1; + }; - auto writer = ClientConnection::Create(client_handler, message_handler, std::move(in_), - "writer", {}, error_message_type_); + auto writer = 
ClientConnection::Create( + client_handler, message_handler, std::move(in_), "writer", {}, error_message_type_); - auto reader = ClientConnection::Create(client_handler, message_handler, std::move(out_), - "reader", {}, error_message_type_); + auto reader = ClientConnection::Create(client_handler, + message_handler, + std::move(out_), + "reader", + {}, + error_message_type_); // If client ID is set, bad message would crash the test. // reader->SetClientID(UniqueID::FromRandom()); diff --git a/src/ray/common/test_util.cc b/src/ray/common/test_util.cc index 8e7dfcfd0..921ca96c0 100644 --- a/src/ray/common/test_util.cc +++ b/src/ray/common/test_util.cc @@ -100,7 +100,8 @@ std::string TestSetupUtil::StartGcsServer(const std::string &redis_address) { std::string gcs_server_socket_name = ray::JoinPaths(ray::GetUserTempDir(), "gcs_server" + ObjectID::FromRandom().Hex()); std::vector<std::string> cmdargs( {TEST_GCS_SERVER_EXEC_PATH, "--redis_address=" + redis_address, + {TEST_GCS_SERVER_EXEC_PATH, + "--redis_address=" + redis_address, "--config_list=" + absl::Base64Escape(R"({"object_timeout_milliseconds": 2000})")}); if (RayConfig::instance().bootstrap_with_gcs()) { @@ -129,15 +130,21 @@ std::string TestSetupUtil::StartRaylet(const std::string &node_ip_address, std::string plasma_store_socket_name = ray::JoinPaths(ray::GetUserTempDir(), "store" + ObjectID::FromRandom().Hex()); std::vector<std::string> cmdargs( {TEST_RAYLET_EXEC_PATH, "--raylet_socket_name=" + raylet_socket_name, - "--store_socket_name=" + plasma_store_socket_name, "--object_manager_port=0", + {TEST_RAYLET_EXEC_PATH, + "--raylet_socket_name=" + raylet_socket_name, + "--store_socket_name=" + plasma_store_socket_name, + "--object_manager_port=0", "--node_manager_port=" + std::to_string(port), - "--node_ip_address=" + node_ip_address, "--redis_port=6379", "--min-worker-port=0", - "--max-worker-port=0", "--maximum_startup_concurrency=10", + "--node_ip_address=" + node_ip_address, + "--redis_port=6379", + "--min-worker-port=0", + "--max-worker-port=0", + "--maximum_startup_concurrency=10", "--static_resource_list=" + resource, - "--python_worker_command=" + - CreateCommandLine({TEST_MOCK_WORKER_EXEC_PATH, plasma_store_socket_name, - raylet_socket_name, std::to_string(port)}), + "--python_worker_command=" + CreateCommandLine({TEST_MOCK_WORKER_EXEC_PATH, + plasma_store_socket_name, + raylet_socket_name, + std::to_string(port)}), "--object_store_memory=10000000"}); if (RayConfig::instance().bootstrap_with_gcs()) { cmdargs.push_back("--gcs-address=" + bootstrap_address); @@ -178,7 +185,8 @@ bool WaitForCondition(std::function<bool()> condition, int timeout_ms) { return false; } -void WaitForExpectedCount(std::atomic<int> &current_count, int expected_count, +void WaitForExpectedCount(std::atomic<int> &current_count, + int expected_count, int timeout_ms) { auto condition = [&current_count, expected_count]() { return current_count == expected_count; diff --git a/src/ray/common/test_util.h b/src/ray/common/test_util.h index 120d818c7..b8225090f 100644 --- a/src/ray/common/test_util.h +++ b/src/ray/common/test_util.h @@ -63,7 +63,8 @@ bool WaitForCondition(std::function<bool()> condition, int timeout_ms); /// \param[in] expected_count The expected count. /// \param[in] timeout_ms Timeout in milliseconds to wait for. /// \return Whether the expected count is met. -void WaitForExpectedCount(std::atomic<int> &current_count, int expected_count, +void WaitForExpectedCount(std::atomic<int> &current_count, + int expected_count, int timeout_ms = 60000); /// Used to kill process whose pid is stored in `socket_name.id` file.
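A note on the mechanism behind all of these hunks: the diffstat shows only two lines added to .clang-format. That hunk is outside this excerpt, but the two bin-packing switches below are the options that would produce this layout (an assumed sketch, since the actual .clang-format change is not visible here):

    BinPackParameters: false
    BinPackArguments: false

Per the clang-format documentation, setting these to false means a parameter or argument list is either kept entirely on one line or broken so that every item has its own line; there is no partial re-wrapping, which is exactly the before/after pattern in the surrounding hunks.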
@@ -117,7 +118,8 @@ class TestSetupUtil { static std::string StartGcsServer(const std::string &redis_address); static void StopGcsServer(const std::string &gcs_server_socket_name); - static std::string StartRaylet(const std::string &node_ip_address, const int &port, + static std::string StartRaylet(const std::string &node_ip_address, + const int &port, const std::string &bootstrap_address, const std::string &resource, std::string *store_socket_name); diff --git a/src/ray/core_worker/actor_handle.cc b/src/ray/core_worker/actor_handle.cc index 741108672..d012f074f 100644 --- a/src/ray/core_worker/actor_handle.cc +++ b/src/ray/core_worker/actor_handle.cc @@ -20,12 +20,18 @@ namespace ray { namespace core { namespace { rpc::ActorHandle CreateInnerActorHandle( - const class ActorID &actor_id, const TaskID &owner_id, - const rpc::Address &owner_address, const class JobID &job_id, - const ObjectID &initial_cursor, const Language actor_language, + const class ActorID &actor_id, + const TaskID &owner_id, + const rpc::Address &owner_address, + const class JobID &job_id, + const ObjectID &initial_cursor, + const Language actor_language, const FunctionDescriptor &actor_creation_task_function_descriptor, - const std::string &extension_data, int64_t max_task_retries, const std::string &name, - const std::string &ray_namespace, int32_t max_pending_calls, + const std::string &extension_data, + int64_t max_task_retries, + const std::string &name, + const std::string &ray_namespace, + int32_t max_pending_calls, bool execute_out_of_order) { rpc::ActorHandle inner; inner.set_actor_id(actor_id.Data(), actor_id.Size()); @@ -79,17 +85,32 @@ rpc::ActorHandle CreateInnerActorHandleFromActorTableData( } // namespace ActorHandle::ActorHandle( - const class ActorID &actor_id, const TaskID &owner_id, - const rpc::Address &owner_address, const class JobID &job_id, - const ObjectID &initial_cursor, const Language actor_language, + const class ActorID &actor_id, + const TaskID &owner_id, + const rpc::Address &owner_address, + const class JobID &job_id, + const ObjectID &initial_cursor, + const Language actor_language, const FunctionDescriptor &actor_creation_task_function_descriptor, - const std::string &extension_data, int64_t max_task_retries, const std::string &name, - const std::string &ray_namespace, int32_t max_pending_calls, + const std::string &extension_data, + int64_t max_task_retries, + const std::string &name, + const std::string &ray_namespace, + int32_t max_pending_calls, bool execute_out_of_order) - : ActorHandle(CreateInnerActorHandle( - actor_id, owner_id, owner_address, job_id, initial_cursor, actor_language, - actor_creation_task_function_descriptor, extension_data, max_task_retries, name, - ray_namespace, max_pending_calls, execute_out_of_order)) {} + : ActorHandle(CreateInnerActorHandle(actor_id, + owner_id, + owner_address, + job_id, + initial_cursor, + actor_language, + actor_creation_task_function_descriptor, + extension_data, + max_task_retries, + name, + ray_namespace, + max_pending_calls, + execute_out_of_order)) {} ActorHandle::ActorHandle(const std::string &serialized) : ActorHandle(CreateInnerActorHandleFromString(serialized)) {} @@ -103,7 +124,8 @@ void ActorHandle::SetActorTaskSpec(TaskSpecBuilder &builder, const ObjectID new_ const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(GetActorID()); const ObjectID actor_creation_dummy_object_id = ObjectID::FromIndex(actor_creation_task_id, /*index=*/1); - builder.SetActorTaskSpec(GetActorID(), actor_creation_dummy_object_id, + 
builder.SetActorTaskSpec(GetActorID(), + actor_creation_dummy_object_id, /*previous_actor_task_dummy_object_id=*/actor_cursor_, task_counter_++); actor_cursor_ = new_cursor; diff --git a/src/ray/core_worker/actor_handle.h b/src/ray/core_worker/actor_handle.h index 56ecd5b7e..0b43c55e3 100644 --- a/src/ray/core_worker/actor_handle.h +++ b/src/ray/core_worker/actor_handle.h @@ -32,13 +32,19 @@ class ActorHandle { : inner_(inner), actor_cursor_(ObjectID::FromBinary(inner_.actor_cursor())) {} // Constructs a new ActorHandle as part of the actor creation process. - ActorHandle(const ActorID &actor_id, const TaskID &owner_id, - const rpc::Address &owner_address, const JobID &job_id, - const ObjectID &initial_cursor, const Language actor_language, + ActorHandle(const ActorID &actor_id, + const TaskID &owner_id, + const rpc::Address &owner_address, + const JobID &job_id, + const ObjectID &initial_cursor, + const Language actor_language, const FunctionDescriptor &actor_creation_task_function_descriptor, - const std::string &extension_data, int64_t max_task_retries, - const std::string &name, const std::string &ray_namespace, - int32_t max_pending_calls, bool execute_out_of_order = false); + const std::string &extension_data, + int64_t max_task_retries, + const std::string &name, + const std::string &ray_namespace, + int32_t max_pending_calls, + bool execute_out_of_order = false); /// Constructs an ActorHandle from a serialized string. explicit ActorHandle(const std::string &serialized); diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index e136c53b0..5b35f5d36 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -30,9 +30,14 @@ ActorID ActorManager::RegisterActorHandle(std::unique_ptr<ActorHandle> actor_han // Note we need to set `cached_actor_name` to empty string as we only cache named actors // when getting them from GCS. - RAY_UNUSED(AddActorHandle(std::move(actor_handle), /*cached_actor_name=*/"", - /*is_owner_handle=*/false, call_site, caller_address, - actor_id, actor_creation_return_id, is_self)); + RAY_UNUSED(AddActorHandle(std::move(actor_handle), + /*cached_actor_name=*/"", + /*is_owner_handle=*/false, + call_site, + caller_address, + actor_id, + actor_creation_return_id, + is_self)); ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); reference_counter_->AddBorrowedObject(actor_handle_id, outer_object_id, owner_address); return actor_id; @@ -48,8 +53,10 @@ std::shared_ptr<ActorHandle> ActorManager::GetActorHandle(const ActorID &actor_i } std::pair<std::shared_ptr<ActorHandle>, Status> ActorManager::GetNamedActorHandle( - const std::string &name, const std::string &ray_namespace, - const std::string &call_site, const rpc::Address &caller_address) { + const std::string &name, + const std::string &ray_namespace, + const std::string &call_site, + const rpc::Address &caller_address) { ActorID actor_id = GetCachedNamedActorID(GenerateCachedActorName(ray_namespace, name)); if (actor_id.IsNil()) { // This call needs to be blocking because we can't return until the actor @@ -64,7 +71,8 @@ std::pair<std::shared_ptr<ActorHandle>, Status> ActorManager::GetNamedActo actor_id = actor_handle->GetActorID(); AddNewActorHandle(std::move(actor_handle), GenerateCachedActorName(result.ray_namespace(), result.name()), - call_site, caller_address, + call_site, + caller_address, /*is_detached*/ true); } else { // Use a NIL actor ID to signal that the actor wasn't found.
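The merge-friendliness argument is easiest to see on a hypothetical call written in both shapes (names invented for illustration; this snippet is not part of the diff):

    #include <string>

    // Invented signature, shaped like the AddActorHandle call above.
    void Register(int actor_id,
                  const std::string &cached_actor_name,
                  bool is_owner_handle,
                  const std::string &call_site);

    void Example(int actor_id, const std::string &call_site) {
      // Fits in the column limit: the arguments stay packed on one line.
      Register(actor_id, "", false, call_site);
      // Overflows the limit: one argument per line, so adding, removing, or
      // renaming a single argument is a one-line change in diff/blame/merge.
      Register(actor_id,
               /*cached_actor_name=*/"MyActorClass-in-the-default-namespace",
               /*is_owner_handle=*/false,
               call_site);
    }

Under the old bin-packed layout the second call would be re-flowed whenever any argument changed length, turning a one-argument edit into a conflict across the whole call.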
@@ -114,28 +122,38 @@ bool ActorManager::AddNewActorHandle(std::unique_ptr actor_handle, // We don't need to add an initial local ref here because it will get added // in AddActorHandle. reference_counter_->AddOwnedObject(actor_creation_return_id, - /*inner_ids=*/{}, caller_address, call_site, + /*inner_ids=*/{}, + caller_address, + call_site, /*object_size*/ -1, /*is_reconstructable=*/true, /*add_local_ref=*/false); } - return AddActorHandle(std::move(actor_handle), cached_actor_name, - /*is_owner_handle=*/!is_detached, call_site, caller_address, - actor_id, actor_creation_return_id); + return AddActorHandle(std::move(actor_handle), + cached_actor_name, + /*is_owner_handle=*/!is_detached, + call_site, + caller_address, + actor_id, + actor_creation_return_id); } bool ActorManager::AddNewActorHandle(std::unique_ptr actor_handle, const std::string &call_site, const rpc::Address &caller_address, bool is_detached) { - return AddNewActorHandle(std::move(actor_handle), /*cached_actor_name=*/"", call_site, - caller_address, is_detached); + return AddNewActorHandle(std::move(actor_handle), + /*cached_actor_name=*/"", + call_site, + caller_address, + is_detached); } bool ActorManager::AddActorHandle(std::unique_ptr actor_handle, const std::string &cached_actor_name, - bool is_owner_handle, const std::string &call_site, + bool is_owner_handle, + const std::string &call_site, const rpc::Address &caller_address, const ActorID &actor_id, const ObjectID &actor_creation_return_id, @@ -160,10 +178,13 @@ bool ActorManager::AddActorHandle(std::unique_ptr actor_handle, if (inserted) { // Register a callback to handle actor notifications. auto actor_notification_callback = - std::bind(&ActorManager::HandleActorStateNotification, this, - std::placeholders::_1, std::placeholders::_2); + std::bind(&ActorManager::HandleActorStateNotification, + this, + std::placeholders::_1, + std::placeholders::_2); RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe( - actor_id, actor_notification_callback, + actor_id, + actor_notification_callback, [this, actor_id, cached_actor_name](Status status) { if (status.ok() && !cached_actor_name.empty()) { { @@ -229,18 +250,22 @@ void ActorManager::HandleActorStateNotification(const ActorID &actor_id, << ", death context type=" << gcs::GetActorDeathCauseString(actor_data.death_cause()); if (actor_data.state() == rpc::ActorTableData::RESTARTING) { - direct_actor_submitter_->DisconnectActor(actor_id, actor_data.num_restarts(), - /*is_dead=*/false, actor_data.death_cause()); + direct_actor_submitter_->DisconnectActor(actor_id, + actor_data.num_restarts(), + /*is_dead=*/false, + actor_data.death_cause()); } else if (actor_data.state() == rpc::ActorTableData::DEAD) { OnActorKilled(actor_id); - direct_actor_submitter_->DisconnectActor(actor_id, actor_data.num_restarts(), - /*is_dead=*/true, actor_data.death_cause()); + direct_actor_submitter_->DisconnectActor(actor_id, + actor_data.num_restarts(), + /*is_dead=*/true, + actor_data.death_cause()); // We cannot erase the actor handle here because clients can still // submit tasks to dead actors. This also means we defer unsubscription, // otherwise we crash when bulk unsubscribing all actor handles. } else if (actor_data.state() == rpc::ActorTableData::ALIVE) { - direct_actor_submitter_->ConnectActor(actor_id, actor_data.address(), - actor_data.num_restarts()); + direct_actor_submitter_->ConnectActor( + actor_id, actor_data.address(), actor_data.num_restarts()); } else { // The actor is being created and not yet ready, just ignore! 
} diff --git a/src/ray/core_worker/actor_manager.h b/src/ray/core_worker/actor_manager.h index 86056e15c..0b83645c3 100644 --- a/src/ray/core_worker/actor_manager.h +++ b/src/ray/core_worker/actor_manager.h @@ -57,7 +57,8 @@ class ActorManager { ActorID RegisterActorHandle(std::unique_ptr actor_handle, const ObjectID &outer_object_id, const std::string &call_site, - const rpc::Address &caller_address, bool is_self = false); + const rpc::Address &caller_address, + bool is_self = false); /// Get a handle to an actor. /// @@ -76,8 +77,10 @@ class ActorManager { /// \param[in] caller_address The rpc address of the calling task. /// \return KV pair of actor handle pointer and status. std::pair, Status> GetNamedActorHandle( - const std::string &name, const std::string &ray_namespace, - const std::string &call_site, const rpc::Address &caller_address); + const std::string &name, + const std::string &ray_namespace, + const std::string &call_site, + const rpc::Address &caller_address); /// Check if an actor handle that corresponds to an actor_id exists. /// \param[in] actor_id The actor id of a handle. @@ -100,7 +103,8 @@ class ActorManager { /// actor. \return True if the handle was added and False if we already had a handle to /// the same actor. bool AddNewActorHandle(std::unique_ptr actor_handle, - const std::string &call_site, const rpc::Address &caller_address, + const std::string &call_site, + const rpc::Address &caller_address, bool is_detached); /// Wait for actor out of scope. @@ -124,7 +128,8 @@ class ActorManager { private: bool AddNewActorHandle(std::unique_ptr actor_handle, const std::string &cached_actor_name, - const std::string &call_site, const rpc::Address &caller_address, + const std::string &call_site, + const rpc::Address &caller_address, bool is_detached); /// Give this worker a handle to an actor. @@ -148,9 +153,12 @@ class ActorManager { /// \return True if the handle was added and False if we already had a handle /// to the same actor. bool AddActorHandle(std::unique_ptr actor_handle, - const std::string &cached_actor_name, bool is_owner_handle, - const std::string &call_site, const rpc::Address &caller_address, - const ActorID &actor_id, const ObjectID &actor_creation_return_id, + const std::string &cached_actor_name, + bool is_owner_handle, + const std::string &call_site, + const rpc::Address &caller_address, + const ActorID &actor_id, + const ObjectID &actor_creation_return_id, bool is_self = false); /// Check if named actor is cached locally. diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 64ce5fa55..0460494f3 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -56,7 +56,8 @@ class RayFunction { /// Options for all tasks (actor and non-actor) except for actor creation. struct TaskOptions { TaskOptions() {} - TaskOptions(std::string name, int num_returns, + TaskOptions(std::string name, + int num_returns, std::unordered_map &resources, const std::string &concurrency_group_name = "", const std::string &serialized_runtime_env_info = "{}") @@ -83,17 +84,21 @@ struct TaskOptions { /// Options for actor creation tasks. 
struct ActorCreationOptions { ActorCreationOptions() {} - ActorCreationOptions(int64_t max_restarts, int64_t max_task_retries, + ActorCreationOptions(int64_t max_restarts, + int64_t max_task_retries, int max_concurrency, const std::unordered_map<std::string, double> &resources, const std::unordered_map<std::string, double> &placement_resources, const std::vector<std::string> &dynamic_worker_options, - std::optional<bool> is_detached, std::string &name, - std::string &ray_namespace, bool is_asyncio, + std::optional<bool> is_detached, + std::string &name, + std::string &ray_namespace, + bool is_asyncio, const rpc::SchedulingStrategy &scheduling_strategy, const std::string &serialized_runtime_env_info = "{}", const std::vector<ConcurrencyGroup> &concurrency_groups = {}, - bool execute_out_of_order = false, int32_t max_pending_calls = -1) + bool execute_out_of_order = false, + int32_t max_pending_calls = -1) : max_restarts(max_restarts), max_task_retries(max_task_retries), max_concurrency(max_concurrency), @@ -159,8 +164,10 @@ using PlacementStrategy = rpc::PlacementStrategy; struct PlacementGroupCreationOptions { PlacementGroupCreationOptions( - std::string name, PlacementStrategy strategy, - std::vector<std::unordered_map<std::string, double>> bundles, bool is_detached) + std::string name, + PlacementStrategy strategy, + std::vector<std::unordered_map<std::string, double>> bundles, + bool is_detached) : name(std::move(name)), strategy(strategy), bundles(std::move(bundles)), @@ -178,8 +185,11 @@ struct PlacementGroupCreationOptions { class ObjectLocation { public: - ObjectLocation(NodeID primary_node_id, uint64_t object_size, - std::vector<NodeID> node_ids, bool is_spilled, std::string spilled_url, + ObjectLocation(NodeID primary_node_id, + uint64_t object_size, + std::vector<NodeID> node_ids, + bool is_spilled, + std::string spilled_url, NodeID spilled_node_id) : primary_node_id_(primary_node_id), object_size_(object_size), diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index d71f312b0..462439f52 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -136,7 +136,8 @@ struct WorkerThreadContext { thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ = nullptr; -WorkerContext::WorkerContext(WorkerType worker_type, const WorkerID &worker_id, +WorkerContext::WorkerContext(WorkerType worker_type, + const WorkerID &worker_id, const JobID &job_id) : worker_type_(worker_type), worker_id_(worker_id), diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d301f2865..8e5448c0d 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -63,7 +63,9 @@ ObjectLocation CreateObjectLocation(const rpc::GetObjectLocationsOwnerReply &rep } bool is_spilled = !object_info.spilled_url().empty(); return ObjectLocation(NodeID::FromBinary(object_info.primary_node_id()), - object_info.object_size(), std::move(node_ids), is_spilled, + object_info.object_size(), + std::move(node_ids), + is_spilled, object_info.spilled_url(), NodeID::FromBinary(object_info.spilled_node_id())); } @@ -88,12 +90,17 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Initialize task receivers.
if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) { RAY_CHECK(options_.task_execution_callback != nullptr); - auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3, - std::placeholders::_4, std::placeholders::_5); + auto execute_task = std::bind(&CoreWorker::ExecuteTask, + this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3, + std::placeholders::_4, + std::placeholders::_5); direct_task_receiver_ = std::make_unique( - worker_context_, task_execution_service_, execute_task, - [this] { return local_raylet_client_->TaskDone(); }); + worker_context_, task_execution_service_, execute_task, [this] { + return local_raylet_client_->TaskDone(); + }); } // Initialize raylet client. @@ -109,11 +116,21 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ NodeID local_raylet_id; int assigned_port; std::string serialized_job_config = options_.serialized_job_config; - local_raylet_client_ = std::make_shared( - io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(), - options_.worker_type, worker_context_.GetCurrentJobID(), options_.runtime_env_hash, - options_.language, options_.node_ip_address, &raylet_client_status, - &local_raylet_id, &assigned_port, &serialized_job_config, options_.startup_token); + local_raylet_client_ = + std::make_shared(io_service_, + std::move(grpc_client), + options_.raylet_socket, + GetWorkerID(), + options_.worker_type, + worker_context_.GetCurrentJobID(), + options_.runtime_env_hash, + options_.language, + options_.node_ip_address, + &raylet_client_status, + &local_raylet_id, + &assigned_port, + &serialized_job_config, + options_.startup_token); if (!raylet_client_status.ok()) { // Avoid using FATAL log or RAY_CHECK here because they may create a core dump file. @@ -142,9 +159,10 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Start RPC server after all the task receivers are properly initialized and we have // our assigned port from the raylet. - core_worker_server_ = std::make_unique( - WorkerTypeString(options_.worker_type), assigned_port, - options_.node_ip_address == "127.0.0.1"); + core_worker_server_ = + std::make_unique(WorkerTypeString(options_.worker_type), + assigned_port, + options_.node_ip_address == "127.0.0.1"); core_worker_server_->RegisterService(grpc_service_); core_worker_server_->Run(); @@ -160,7 +178,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Begin to get gcs server address from raylet. 
gcs_server_address_updater_ = std::make_unique( - options_.raylet_ip_address, options_.node_manager_port, + options_.raylet_ip_address, + options_.node_manager_port, [this](std::string ip, int port) { absl::MutexLock lock(&gcs_server_address_mutex_); gcs_server_address_.first = ip; @@ -225,8 +244,10 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ reference_counter_ = std::make_shared( rpc_address_, /*object_info_publisher=*/object_info_publisher_.get(), - /*object_info_subscriber=*/object_info_subscriber_.get(), check_node_alive_fn, - RayConfig::instance().lineage_pinning_enabled(), [this](const rpc::Address &addr) { + /*object_info_subscriber=*/object_info_subscriber_.get(), + check_node_alive_fn, + RayConfig::instance().lineage_pinning_enabled(), + [this](const rpc::Address &addr) { return std::shared_ptr( new rpc::CoreWorkerClient(addr, *client_call_manager_)); }); @@ -238,14 +259,18 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ } plasma_store_provider_.reset(new CoreWorkerPlasmaStoreProvider( - options_.store_socket, local_raylet_client_, reference_counter_, + options_.store_socket, + local_raylet_client_, + reference_counter_, options_.check_signals, /*warmup=*/ (options_.worker_type != WorkerType::SPILL_WORKER && options_.worker_type != WorkerType::RESTORE_WORKER), /*get_current_call_site=*/boost::bind(&CoreWorker::CurrentCallSite, this))); memory_store_.reset(new CoreWorkerMemoryStore( - reference_counter_, local_raylet_client_, options_.check_signals, + reference_counter_, + local_raylet_client_, + options_.check_signals, [this](const RayObject &obj) { // Run this on the event loop to avoid calling back into the language runtime // from the middle of user operations. @@ -261,12 +286,15 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ periodical_runner_.RunFnPeriodically([this] { InternalHeartbeat(); }, kInternalHeartbeatMillis); - auto push_error_callback = [this](const JobID &job_id, const std::string &type, - const std::string &error_message, double timestamp) { + auto push_error_callback = [this](const JobID &job_id, + const std::string &type, + const std::string &error_message, + double timestamp) { return PushError(job_id, type, error_message, timestamp); }; task_manager_.reset(new TaskManager( - memory_store_, reference_counter_, + memory_store_, + reference_counter_, /*put_in_local_plasma_callback=*/ [this](const RayObject &object, const ObjectID &object_id) { RAY_CHECK_OK(PutInLocalPlasmaStore(object, object_id, /*pin_object=*/true)); @@ -294,7 +322,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ } } }, - push_error_callback, RayConfig::instance().max_lineage_bytes())); + push_error_callback, + RayConfig::instance().max_lineage_bytes())); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. 
This allows us to push errors @@ -305,10 +334,12 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ if (options_.worker_type == WorkerType::DRIVER) { TaskSpecBuilder builder; const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID()); - builder.SetDriverTaskSpec(task_id, options_.language, + builder.SetDriverTaskSpec(task_id, + options_.language, worker_context_.GetCurrentJobID(), TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), - GetCallerId(), rpc_address_); + GetCallerId(), + rpc_address_); // Drivers are never re-executed. SetCurrentTaskId(task_id, /*attempt_number=*/0); } @@ -336,9 +367,12 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ actor_creator_ = std::make_shared(gcs_client_); direct_actor_submitter_ = std::shared_ptr( - new CoreWorkerDirectActorTaskSubmitter(*core_worker_client_pool_, *memory_store_, - *task_manager_, *actor_creator_, - on_excess_queueing, io_service_)); + new CoreWorkerDirectActorTaskSubmitter(*core_worker_client_pool_, + *memory_store_, + *task_manager_, + *actor_creator_, + on_excess_queueing, + io_service_)); auto node_addr_factory = [this](const NodeID &node_id) { absl::optional addr; @@ -359,30 +393,41 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ std::make_shared(rpc_address_)); direct_task_submitter_ = std::make_unique( - rpc_address_, local_raylet_client_, core_worker_client_pool_, raylet_client_factory, - std::move(lease_policy), memory_store_, task_manager_, local_raylet_id, - GetWorkerType(), RayConfig::instance().worker_lease_timeout_milliseconds(), - actor_creator_, worker_context_.GetCurrentJobID(), + rpc_address_, + local_raylet_client_, + core_worker_client_pool_, + raylet_client_factory, + std::move(lease_policy), + memory_store_, + task_manager_, + local_raylet_id, + GetWorkerType(), + RayConfig::instance().worker_lease_timeout_milliseconds(), + actor_creator_, + worker_context_.GetCurrentJobID(), boost::asio::steady_timer(io_service_), RayConfig::instance().max_pending_lease_requests_per_scheduling_category()); - auto report_locality_data_callback = - [this](const ObjectID &object_id, const absl::flat_hash_set &locations, - uint64_t object_size) { - reference_counter_->ReportLocalityData(object_id, locations, object_size); - }; - future_resolver_.reset(new FutureResolver(memory_store_, reference_counter_, + auto report_locality_data_callback = [this]( + const ObjectID &object_id, + const absl::flat_hash_set &locations, + uint64_t object_size) { + reference_counter_->ReportLocalityData(object_id, locations, object_size); + }; + future_resolver_.reset(new FutureResolver(memory_store_, + reference_counter_, std::move(report_locality_data_callback), - core_worker_client_pool_, rpc_address_)); + core_worker_client_pool_, + rpc_address_)); // Unfortunately the raylet client has to be constructed after the receivers. 
if (direct_task_receiver_ != nullptr) { task_argument_waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_)); - direct_task_receiver_->Init(core_worker_client_pool_, rpc_address_, - task_argument_waiter_); + direct_task_receiver_->Init( + core_worker_client_pool_, rpc_address_, task_argument_waiter_); } - actor_manager_ = std::make_unique(gcs_client_, direct_actor_submitter_, - reference_counter_); + actor_manager_ = std::make_unique( + gcs_client_, direct_actor_submitter_, reference_counter_); std::function object_lookup_fn; @@ -411,13 +456,19 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ return Status::OK(); }; object_recovery_manager_ = std::make_unique( - rpc_address_, raylet_client_factory, local_raylet_client_, object_lookup_fn, - task_manager_, reference_counter_, memory_store_, + rpc_address_, + raylet_client_factory, + local_raylet_client_, + object_lookup_fn, + task_manager_, + reference_counter_, + memory_store_, [this](const ObjectID &object_id, rpc::ErrorType reason, bool pin_object) { RAY_LOG(DEBUG) << "Failed to recover object " << object_id << " due to " << rpc::ErrorType_Name(reason); RAY_CHECK_OK(Put(RayObject(reason), - /*contained_object_ids=*/{}, object_id, + /*contained_object_ids=*/{}, + object_id, /*pin_object=*/pin_object)); }); @@ -777,7 +828,8 @@ std::vector CoreWorker::GetObjectRefs( return refs; } -void CoreWorker::GetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner_address, +void CoreWorker::GetOwnershipInfo(const ObjectID &object_id, + rpc::Address *owner_address, std::string *serialized_object_status) { auto has_owner = reference_counter_->GetOwner(object_id, owner_address); RAY_CHECK(has_owner) @@ -801,8 +853,10 @@ void CoreWorker::GetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner } void CoreWorker::RegisterOwnershipInfoAndResolveFuture( - const ObjectID &object_id, const ObjectID &outer_object_id, - const rpc::Address &owner_address, const std::string &serialized_object_status) { + const ObjectID &object_id, + const ObjectID &outer_object_id, + const rpc::Address &owner_address, + const std::string &serialized_object_status) { // Add the object's owner to the local metadata in case it gets serialized // again. reference_counter_->AddBorrowedObject(object_id, outer_object_id, owner_address); @@ -812,8 +866,8 @@ void CoreWorker::RegisterOwnershipInfoAndResolveFuture( if (object_status.has_object() && !reference_counter_->OwnedByUs(object_id)) { // We already have the inlined object status, process it immediately. - future_resolver_->ProcessResolvedObject(object_id, owner_address, Status::OK(), - object_status); + future_resolver_->ProcessResolvedObject( + object_id, owner_address, Status::OK(), object_status); } else { // We will ask the owner about the object until the object is // created or we can no longer reach the owner. 
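One nuance visible in the core_worker.cc hunks above: when the final argument is a lambda, clang-format can keep the leading arguments packed on one line and break only inside the lambda body, as in the direct_task_receiver_ construction; when the leading arguments themselves overflow, the lambda drops onto its own line like any other argument. A minimal sketch of the first shape (names invented for illustration):

    #include <functional>

    // Invented helper, for illustration only.
    void Post(int queue_id, std::function<bool()> done);

    void Example() {
      // The leading argument fits, so only the trailing lambda body breaks.
      Post(0, [] {
        return true;
      });
    }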
@@ -826,9 +880,13 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) { *object_id = ObjectID::FromIndex(worker_context_.GetCurrentInternalTaskId(), worker_context_.GetNextPutIndex()); - reference_counter_->AddOwnedObject(*object_id, contained_object_ids, rpc_address_, - CurrentCallSite(), object.GetSize(), - /*is_reconstructable=*/false, /*add_local_ref=*/true, + reference_counter_->AddOwnedObject(*object_id, + contained_object_ids, + rpc_address_, + CurrentCallSite(), + object.GetSize(), + /*is_reconstructable=*/false, + /*add_local_ref=*/true, NodeID::FromBinary(rpc_address_.raylet_id())); auto status = Put(object, contained_object_ids, *object_id, /*pin_object=*/true); if (!status.ok()) { @@ -838,7 +896,8 @@ Status CoreWorker::Put(const RayObject &object, } Status CoreWorker::PutInLocalPlasmaStore(const RayObject &object, - const ObjectID &object_id, bool pin_object) { + const ObjectID &object_id, + bool pin_object) { bool object_exists; RAY_RETURN_NOT_OK(plasma_store_provider_->Put( object, object_id, /* owner_address = */ rpc_address_, &object_exists)); @@ -847,7 +906,8 @@ Status CoreWorker::PutInLocalPlasmaStore(const RayObject &object, // Tell the raylet to pin the object **after** it is created. RAY_LOG(DEBUG) << "Pinning put object " << object_id; local_raylet_client_->PinObjectIDs( - rpc_address_, {object_id}, + rpc_address_, + {object_id}, [this, object_id](const Status &status, const rpc::PinObjectIDsReply &reply) { // Only release the object once the raylet has responded to avoid the race // condition that the object could be evicted before the raylet pins it. @@ -866,7 +926,8 @@ Status CoreWorker::PutInLocalPlasmaStore(const RayObject &object, Status CoreWorker::Put(const RayObject &object, const std::vector &contained_object_ids, - const ObjectID &object_id, bool pin_object) { + const ObjectID &object_id, + bool pin_object) { RAY_RETURN_NOT_OK(WaitForActorRegistered(contained_object_ids)); if (options_.is_local_mode) { RAY_LOG(DEBUG) << "Put " << object_id << " in memory store"; @@ -877,10 +938,14 @@ Status CoreWorker::Put(const RayObject &object, } Status CoreWorker::CreateOwnedAndIncrementLocalRef( - const std::shared_ptr &metadata, const size_t data_size, - const std::vector &contained_object_ids, ObjectID *object_id, - std::shared_ptr *data, bool created_by_worker, - const std::unique_ptr &owner_address, bool inline_small_object) { + const std::shared_ptr &metadata, + const size_t data_size, + const std::vector &contained_object_ids, + ObjectID *object_id, + std::shared_ptr *data, + bool created_by_worker, + const std::unique_ptr &owner_address, + bool inline_small_object) { auto status = WaitForActorRegistered(contained_object_ids); if (!status.ok()) { return status; @@ -891,8 +956,11 @@ Status CoreWorker::CreateOwnedAndIncrementLocalRef( owner_address != nullptr ? *owner_address : rpc_address_; bool owned_by_us = real_owner_address.worker_id() == rpc_address_.worker_id(); if (owned_by_us) { - reference_counter_->AddOwnedObject(*object_id, contained_object_ids, rpc_address_, - CurrentCallSite(), data_size + metadata->Size(), + reference_counter_->AddOwnedObject(*object_id, + contained_object_ids, + rpc_address_, + CurrentCallSite(), + data_size + metadata->Size(), /*is_reconstructable=*/false, /*add_local_ref=*/true, NodeID::FromBinary(rpc_address_.raylet_id())); @@ -903,9 +971,11 @@ Status CoreWorker::CreateOwnedAndIncrementLocalRef( // by invoking `AddLocalReference` first. 
Note that in worker.py we set // skip_adding_local_ref=True to avoid double referencing the object. AddLocalReference(*object_id); - RAY_UNUSED(reference_counter_->AddBorrowedObject( - *object_id, ObjectID::Nil(), real_owner_address, - /*foreign_owner_already_monitoring=*/true)); + RAY_UNUSED( + reference_counter_->AddBorrowedObject(*object_id, + ObjectID::Nil(), + real_owner_address, + /*foreign_owner_already_monitoring=*/true)); // Remote call `AssignObjectOwner()`. rpc::AssignObjectOwnerRequest request; @@ -932,8 +1002,11 @@ Status CoreWorker::CreateOwnedAndIncrementLocalRef( *data = std::make_shared(data_size); } else { if (status.ok()) { - status = plasma_store_provider_->Create(metadata, data_size, *object_id, - /* owner_address = */ rpc_address_, data, + status = plasma_store_provider_->Create(metadata, + data_size, + *object_id, + /* owner_address = */ rpc_address_, + data, created_by_worker); } if (!status.ok()) { @@ -950,20 +1023,23 @@ Status CoreWorker::CreateOwnedAndIncrementLocalRef( } Status CoreWorker::CreateExisting(const std::shared_ptr &metadata, - const size_t data_size, const ObjectID &object_id, + const size_t data_size, + const ObjectID &object_id, const rpc::Address &owner_address, - std::shared_ptr *data, bool created_by_worker) { + std::shared_ptr *data, + bool created_by_worker) { if (options_.is_local_mode) { return Status::NotImplemented( "Creating an object with a pre-existing ObjectID is not supported in local " "mode"); } else { - return plasma_store_provider_->Create(metadata, data_size, object_id, owner_address, - data, created_by_worker); + return plasma_store_provider_->Create( + metadata, data_size, object_id, owner_address, data, created_by_worker); } } -Status CoreWorker::SealOwned(const ObjectID &object_id, bool pin_object, +Status CoreWorker::SealOwned(const ObjectID &object_id, + bool pin_object, const std::unique_ptr &owner_address) { auto status = SealExisting(object_id, pin_object, std::move(owner_address)); if (status.ok()) return status; @@ -976,14 +1052,16 @@ Status CoreWorker::SealOwned(const ObjectID &object_id, bool pin_object, return status; } -Status CoreWorker::SealExisting(const ObjectID &object_id, bool pin_object, +Status CoreWorker::SealExisting(const ObjectID &object_id, + bool pin_object, const std::unique_ptr &owner_address) { RAY_RETURN_NOT_OK(plasma_store_provider_->Seal(object_id)); if (pin_object) { // Tell the raylet to pin the object **after** it is created. RAY_LOG(DEBUG) << "Pinning sealed object " << object_id; local_raylet_client_->PinObjectIDs( - owner_address != nullptr ? *owner_address : rpc_address_, {object_id}, + owner_address != nullptr ? *owner_address : rpc_address_, + {object_id}, [this, object_id](const Status &status, const rpc::PinObjectIDsReply &reply) { // Only release the object once the raylet has responded to avoid the race // condition that the object could be evicted before the raylet pins it. 
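Because every hunk in this patch is mechanical clang-format output, downstream branches do not need to resolve these conflicts by hand: taking their own version of a conflicted file and re-running clang-format in place with the repository style (for example `clang-format -i -style=file src/ray/core_worker/core_worker.cc`, or the repository's format script if one is configured) regenerates exactly this layout.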
@@ -1000,7 +1078,8 @@ Status CoreWorker::SealExisting(const ObjectID &object_id, bool pin_object,
   return Status::OK();
 }
 
-Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_ms,
+Status CoreWorker::Get(const std::vector<ObjectID> &ids,
+                       const int64_t timeout_ms,
                        std::vector<std::shared_ptr<RayObject>> *results) {
   results->resize(ids.size(), nullptr);
 
@@ -1012,8 +1091,8 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
   auto start_time = current_time_ms();
 
   if (!memory_object_ids.empty()) {
-    RAY_RETURN_NOT_OK(memory_store_->Get(memory_object_ids, timeout_ms, worker_context_,
-                                         &result_map, &got_exception));
+    RAY_RETURN_NOT_OK(memory_store_->Get(
+        memory_object_ids, timeout_ms, worker_context_, &result_map, &got_exception));
   }
 
   // Erase any objects that were promoted to plasma from the results. These get
@@ -1037,8 +1116,10 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
                                   timeout_ms - (current_time_ms() - start_time));
     }
     RAY_LOG(DEBUG) << "Plasma GET timeout " << local_timeout_ms;
-    RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids, local_timeout_ms,
-                                                  worker_context_, &result_map,
+    RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids,
+                                                  local_timeout_ms,
+                                                  worker_context_,
+                                                  &result_map,
                                                   &got_exception));
   }
 
@@ -1088,7 +1169,8 @@ Status CoreWorker::GetIfLocal(const std::vector<ObjectID> &ids,
   return Status::OK();
 }
 
-Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object,
+Status CoreWorker::Contains(const ObjectID &object_id,
+                            bool *has_object,
                             bool *is_in_plasma) {
   bool found = false;
   bool in_plasma = false;
@@ -1116,9 +1198,12 @@ void RetryObjectInPlasmaErrors(std::shared_ptr<CoreWorkerMemoryStore> &memory_st
     auto ready_iter = ready.find(mem_id);
     if (ready_iter != ready.end()) {
       std::vector<std::shared_ptr<RayObject>> found;
-      RAY_CHECK_OK(memory_store->Get({mem_id}, /*num_objects=*/1, /*timeout=*/0,
+      RAY_CHECK_OK(memory_store->Get({mem_id},
+                                     /*num_objects=*/1,
+                                     /*timeout=*/0,
                                      worker_context,
-                                     /*remote_after_get=*/false, &found));
+                                     /*remote_after_get=*/false,
+                                     &found));
       if (found.size() == 1 && found[0]->IsInPlasmaError()) {
         plasma_object_ids.insert(mem_id);
         ready.erase(ready_iter);
@@ -1128,8 +1213,10 @@ void RetryObjectInPlasmaErrors(std::shared_ptr<CoreWorkerMemoryStore> &memory_st
   }
 }
 
-Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
-                        int64_t timeout_ms, std::vector<bool> *results,
+Status CoreWorker::Wait(const std::vector<ObjectID> &ids,
+                        int num_objects,
+                        int64_t timeout_ms,
+                        std::vector<bool> *results,
                         bool fetch_local) {
   results->resize(ids.size(), false);
 
@@ -1149,22 +1236,26 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
   int64_t start_time = current_time_ms();
   RAY_RETURN_NOT_OK(memory_store_->Wait(
       memory_object_ids,
-      std::min(static_cast<int>(memory_object_ids.size()), num_objects), timeout_ms,
-      worker_context_, &ready));
+      std::min(static_cast<int>(memory_object_ids.size()), num_objects),
+      timeout_ms,
+      worker_context_,
+      &ready));
   RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
   if (timeout_ms > 0) {
     timeout_ms =
         std::max(0, static_cast<int>(timeout_ms - (current_time_ms() - start_time)));
   }
   if (fetch_local) {
-    RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
-                              plasma_object_ids, ready);
+    RetryObjectInPlasmaErrors(
+        memory_store_, worker_context_, memory_object_ids, plasma_object_ids, ready);
    if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
       RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
           plasma_object_ids,
           std::min(static_cast<int>(plasma_object_ids.size()),
                    num_objects - static_cast<int>(ready.size())),
-          timeout_ms, worker_context_, &ready));
+          timeout_ms,
+          worker_context_,
+          &ready));
     }
   }
   RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
@@ -1198,7 +1289,8 @@ Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_on
 }
 
 Status CoreWorker::GetLocationFromOwner(
-    const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
+    const std::vector<ObjectID> &object_ids,
+    int64_t timeout_ms,
     std::vector<std::shared_ptr<ObjectLocation>> *results) {
   results->resize(object_ids.size());
   if (object_ids.empty()) {
@@ -1282,8 +1374,10 @@ TaskID CoreWorker::GetCallerId() const {
   return caller_id;
 }
 
-Status CoreWorker::PushError(const JobID &job_id, const std::string &type,
-                             const std::string &error_message, double timestamp) {
+Status CoreWorker::PushError(const JobID &job_id,
+                             const std::string &type,
+                             const std::string &error_message,
+                             double timestamp) {
   if (options_.is_local_mode) {
     RAY_LOG(ERROR) << "Pushed Error with JobID: " << job_id << " of type: " << type
                    << " with message: " << error_message << " at time: " << timestamp;
@@ -1305,8 +1399,8 @@ void CoreWorker::SpillOwnedObject(const ObjectID &object_id,
   bool owned_by_us = false;
   NodeID pinned_at;
   bool spilled = false;
-  RAY_CHECK(reference_counter_->IsPlasmaObjectPinnedOrSpilled(object_id, &owned_by_us,
-                                                              &pinned_at, &spilled));
+  RAY_CHECK(reference_counter_->IsPlasmaObjectPinnedOrSpilled(
+      object_id, &owned_by_us, &pinned_at, &spilled));
   RAY_CHECK(owned_by_us);
   if (spilled) {
     // The object has already been spilled.
@@ -1321,13 +1415,14 @@ void CoreWorker::SpillOwnedObject(const ObjectID &object_id,
 
   // Ask the raylet to spill the object.
   RAY_LOG(DEBUG) << "Sending spill request to raylet for object " << object_id;
-  auto raylet_client =
-      std::make_shared<raylet::RayletClient>(rpc::NodeManagerWorkerClient::make(
-          node->node_manager_address(), node->node_manager_port(),
-          *client_call_manager_));
+  auto raylet_client = std::make_shared<raylet::RayletClient>(
+      rpc::NodeManagerWorkerClient::make(node->node_manager_address(),
+                                         node->node_manager_port(),
+                                         *client_call_manager_));
   raylet_client->RequestObjectSpillage(
-      object_id, [object_id, callback](const Status &status,
-                                       const rpc::RequestObjectSpillageReply &reply) {
+      object_id,
+      [object_id, callback](const Status &status,
+                            const rpc::RequestObjectSpillageReply &reply) {
         if (!status.ok() || !reply.success()) {
           RAY_LOG(ERROR) << "Failed to spill object " << object_id
                          << ", raylet unreachable or object could not be spilled.";
@@ -1510,23 +1605,42 @@ std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvIn
 }
 
 void CoreWorker::BuildCommonTaskSpec(
-    TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id,
-    const std::string &name, const TaskID &current_task_id, uint64_t task_index,
-    const TaskID &caller_id, const rpc::Address &address, const RayFunction &function,
-    const std::vector<std::unique_ptr<TaskArg>> &args, uint64_t num_returns,
+    TaskSpecBuilder &builder,
+    const JobID &job_id,
+    const TaskID &task_id,
+    const std::string &name,
+    const TaskID &current_task_id,
+    uint64_t task_index,
+    const TaskID &caller_id,
+    const rpc::Address &address,
+    const RayFunction &function,
+    const std::vector<std::unique_ptr<TaskArg>> &args,
+    uint64_t num_returns,
     const std::unordered_map<std::string, double> &required_resources,
     const std::unordered_map<std::string, double> &required_placement_resources,
-    const std::string &debugger_breakpoint, int64_t depth,
+    const std::string &debugger_breakpoint,
+    int64_t depth,
     const std::string &serialized_runtime_env_info,
     const std::string &concurrency_group_name) {
   // Build common task spec.
auto override_runtime_env_info = OverrideTaskOrActorRuntimeEnvInfo(serialized_runtime_env_info); - builder.SetCommonTaskSpec( - task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, - current_task_id, task_index, caller_id, address, num_returns, required_resources, - required_placement_resources, debugger_breakpoint, depth, override_runtime_env_info, - concurrency_group_name); + builder.SetCommonTaskSpec(task_id, + name, + function.GetLanguage(), + function.GetFunctionDescriptor(), + job_id, + current_task_id, + task_index, + caller_id, + address, + num_returns, + required_resources, + required_placement_resources, + debugger_breakpoint, + depth, + override_runtime_env_info, + concurrency_group_name); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); @@ -1534,8 +1648,11 @@ void CoreWorker::BuildCommonTaskSpec( } std::vector CoreWorker::SubmitTask( - const RayFunction &function, const std::vector> &args, - const TaskOptions &task_options, int max_retries, bool retry_exceptions, + const RayFunction &function, + const std::vector> &args, + const TaskOptions &task_options, + int max_retries, + bool retry_exceptions, const rpc::SchedulingStrategy &scheduling_strategy, const std::string &debugger_breakpoint) { RAY_CHECK(scheduling_strategy.scheduling_strategy_case() != @@ -1543,9 +1660,9 @@ std::vector CoreWorker::SubmitTask( TaskSpecBuilder builder; const auto next_task_index = worker_context_.GetNextTaskIndex(); - const auto task_id = - TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), - worker_context_.GetCurrentInternalTaskId(), next_task_index); + const auto task_id = TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentInternalTaskId(), + next_task_index); auto constrained_resources = AddPlacementGroupConstraint(task_options.resources, scheduling_strategy); @@ -1555,11 +1672,22 @@ std::vector CoreWorker::SubmitTask( : task_options.name; int64_t depth = worker_context_.GetTaskDepth() + 1; // TODO(ekl) offload task building onto a thread pool for performance - BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, task_name, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), - rpc_address_, function, args, task_options.num_returns, - constrained_resources, required_resources, debugger_breakpoint, - depth, task_options.serialized_runtime_env_info); + BuildCommonTaskSpec(builder, + worker_context_.GetCurrentJobID(), + task_id, + task_name, + worker_context_.GetCurrentTaskID(), + next_task_index, + GetCallerId(), + rpc_address_, + function, + args, + task_options.num_returns, + constrained_resources, + required_resources, + debugger_breakpoint, + depth, + task_options.serialized_runtime_env_info); builder.SetNormalTaskSpec(max_retries, retry_exceptions, scheduling_strategy); TaskSpecification task_spec = builder.Build(); RAY_LOG(DEBUG) << "Submitting normal task " << task_spec.DebugString(); @@ -1567,8 +1695,8 @@ std::vector CoreWorker::SubmitTask( if (options_.is_local_mode) { returned_refs = ExecuteTaskLocalMode(task_spec); } else { - returned_refs = task_manager_->AddPendingTask(task_spec.CallerAddress(), task_spec, - CurrentCallSite(), max_retries); + returned_refs = task_manager_->AddPendingTask( + task_spec.CallerAddress(), task_spec, CurrentCallSite(), max_retries); io_service_.post( [this, task_spec]() { RAY_UNUSED(direct_task_submitter_->SubmitTask(task_spec)); @@ -1603,9 +1731,9 @@ Status CoreWorker::CreateActor(const RayFunction &function, } const auto 
next_task_index = worker_context_.GetNextTaskIndex(); - const ActorID actor_id = - ActorID::Of(worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), - next_task_index); + const ActorID actor_id = ActorID::Of(worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentTaskID(), + next_task_index); const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(actor_id); const JobID job_id = worker_context_.GetCurrentJobID(); // Propagate existing environment variable overrides, but override them with any new @@ -1623,10 +1751,21 @@ Status CoreWorker::CreateActor(const RayFunction &function, ? function.GetFunctionDescriptor()->DefaultTaskName() : actor_name + ":" + function.GetFunctionDescriptor()->CallString(); int64_t depth = worker_context_.GetTaskDepth() + 1; - BuildCommonTaskSpec(builder, job_id, actor_creation_task_id, task_name, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), - rpc_address_, function, args, 1, new_resource, - new_placement_resources, "" /* debugger_breakpoint */, depth, + BuildCommonTaskSpec(builder, + job_id, + actor_creation_task_id, + task_name, + worker_context_.GetCurrentTaskID(), + next_task_index, + GetCallerId(), + rpc_address_, + function, + args, + 1, + new_resource, + new_placement_resources, + "" /* debugger_breakpoint */, + depth, actor_creation_options.serialized_runtime_env_info); // If the namespace is not specified, get it from the job. @@ -1634,26 +1773,40 @@ Status CoreWorker::CreateActor(const RayFunction &function, ? job_config_->ray_namespace() : actor_creation_options.ray_namespace); auto actor_handle = std::make_unique( - actor_id, GetCallerId(), rpc_address_, job_id, + actor_id, + GetCallerId(), + rpc_address_, + job_id, /*actor_cursor=*/ObjectID::FromIndex(actor_creation_task_id, 1), - function.GetLanguage(), function.GetFunctionDescriptor(), extension_data, - actor_creation_options.max_task_retries, actor_name, ray_namespace, + function.GetLanguage(), + function.GetFunctionDescriptor(), + extension_data, + actor_creation_options.max_task_retries, + actor_name, + ray_namespace, actor_creation_options.max_pending_calls, actor_creation_options.execute_out_of_order); std::string serialized_actor_handle; actor_handle->Serialize(&serialized_actor_handle); - builder.SetActorCreationTaskSpec( - actor_id, serialized_actor_handle, actor_creation_options.scheduling_strategy, - actor_creation_options.max_restarts, actor_creation_options.max_task_retries, - actor_creation_options.dynamic_worker_options, - actor_creation_options.max_concurrency, is_detached, actor_name, ray_namespace, - actor_creation_options.is_asyncio, actor_creation_options.concurrency_groups, - extension_data, actor_creation_options.execute_out_of_order); + builder.SetActorCreationTaskSpec(actor_id, + serialized_actor_handle, + actor_creation_options.scheduling_strategy, + actor_creation_options.max_restarts, + actor_creation_options.max_task_retries, + actor_creation_options.dynamic_worker_options, + actor_creation_options.max_concurrency, + is_detached, + actor_name, + ray_namespace, + actor_creation_options.is_asyncio, + actor_creation_options.concurrency_groups, + extension_data, + actor_creation_options.execute_out_of_order); // Add the actor handle before we submit the actor creation task, since the // actor handle must be in scope by the time the GCS sends the // WaitForActorOutOfScopeRequest. 
- RAY_CHECK(actor_manager_->AddNewActorHandle(std::move(actor_handle), CurrentCallSite(), - rpc_address_, is_detached)) + RAY_CHECK(actor_manager_->AddNewActorHandle( + std::move(actor_handle), CurrentCallSite(), rpc_address_, is_detached)) << "Actor " << actor_id << " already exists"; *return_actor_id = actor_id; TaskSpecification task_spec = builder.Build(); @@ -1677,8 +1830,8 @@ Status CoreWorker::CreateActor(const RayFunction &function, max_retries = std::max((int64_t)RayConfig::instance().actor_creation_min_retries(), actor_creation_options.max_restarts); } - task_manager_->AddPendingTask(rpc_address_, task_spec, CurrentCallSite(), - max_retries); + task_manager_->AddPendingTask( + rpc_address_, task_spec, CurrentCallSite(), max_retries); if (actor_name.empty()) { io_service_.post( @@ -1729,11 +1882,14 @@ Status CoreWorker::CreatePlacementGroup( } const PlacementGroupID placement_group_id = PlacementGroupID::FromRandom(); PlacementGroupSpecBuilder builder; - builder.SetPlacementGroupSpec( - placement_group_id, placement_group_creation_options.name, - placement_group_creation_options.bundles, placement_group_creation_options.strategy, - placement_group_creation_options.is_detached, worker_context_.GetCurrentJobID(), - worker_context_.GetCurrentActorID(), worker_context_.CurrentActorDetached()); + builder.SetPlacementGroupSpec(placement_group_id, + placement_group_creation_options.name, + placement_group_creation_options.bundles, + placement_group_creation_options.strategy, + placement_group_creation_options.is_detached, + worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentActorID(), + worker_context_.CurrentActorDetached()); PlacementGroupSpecification placement_group_spec = builder.Build(); *return_placement_group_id = placement_group_id; RAY_LOG(INFO) << "Submitting Placement Group creation to GCS: " << placement_group_id; @@ -1782,8 +1938,10 @@ Status CoreWorker::WaitPlacementGroupReady(const PlacementGroupID &placement_gro } std::optional> CoreWorker::SubmitActorTask( - const ActorID &actor_id, const RayFunction &function, - const std::vector> &args, const TaskOptions &task_options) { + const ActorID &actor_id, + const RayFunction &function, + const std::vector> &args, + const TaskOptions &task_options) { absl::ReleasableMutexLock lock(&actor_task_mutex_); /// Check whether backpressure may happen at the very beginning of submitting a task. if (direct_actor_submitter_->PendingTasksFull(actor_id)) { @@ -1800,9 +1958,11 @@ std::optional> CoreWorker::SubmitActorTask( // Build common task spec. TaskSpecBuilder builder; const auto next_task_index = worker_context_.GetNextTaskIndex(); - const TaskID actor_task_id = TaskID::ForActorTask( - worker_context_.GetCurrentJobID(), worker_context_.GetCurrentInternalTaskId(), - next_task_index, actor_handle->GetActorID()); + const TaskID actor_task_id = + TaskID::ForActorTask(worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentInternalTaskId(), + next_task_index, + actor_handle->GetActorID()); const std::unordered_map required_resources; const auto task_name = task_options.name.empty() ? function.GetFunctionDescriptor()->DefaultTaskName() @@ -1811,12 +1971,22 @@ std::optional> CoreWorker::SubmitActorTask( // Depth shouldn't matter for an actor task, but for consistency it should be // the same as the actor creation task's depth. 
int64_t depth = worker_context_.GetTaskDepth(); - BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), - rpc_address_, function, args, num_returns, task_options.resources, - required_resources, "", /* debugger_breakpoint */ - depth, /*depth*/ - "{}", /* serialized_runtime_env_info */ + BuildCommonTaskSpec(builder, + actor_handle->CreationJobID(), + actor_task_id, + task_name, + worker_context_.GetCurrentTaskID(), + next_task_index, + GetCallerId(), + rpc_address_, + function, + args, + num_returns, + task_options.resources, + required_resources, + "", /* debugger_breakpoint */ + depth, /*depth*/ + "{}", /* serialized_runtime_env_info */ task_options.concurrency_group_name); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. @@ -1844,7 +2014,8 @@ std::optional> CoreWorker::SubmitActorTask( return {std::move(returned_refs)}; } -Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill, +Status CoreWorker::CancelTask(const ObjectID &object_id, + bool force_kill, bool recursive) { if (actor_manager_->CheckActorHandleExists(object_id.TaskId().ActorId())) { return Status::Invalid("Actor task cancellation is not supported."); @@ -1854,8 +2025,8 @@ Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill, return Status::Invalid("No owner found for object."); } if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { - return direct_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill, - recursive); + return direct_task_submitter_->CancelRemoteTask( + object_id, obj_addr, force_kill, recursive); } auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId()); @@ -1894,8 +2065,8 @@ Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill, bool no_r [this, p = &p, actor_id, force_kill, no_restart]() { auto cb = [this, p, actor_id, force_kill, no_restart](Status status) mutable { if (status.ok()) { - RAY_CHECK_OK(gcs_client_->Actors().AsyncKillActor(actor_id, force_kill, - no_restart, nullptr)); + RAY_CHECK_OK(gcs_client_->Actors().AsyncKillActor( + actor_id, force_kill, no_restart, nullptr)); } p->set_value(std::move(status)); }; @@ -1936,11 +2107,12 @@ void CoreWorker::RemoveActorHandleReference(const ActorID &actor_id) { ActorID CoreWorker::DeserializeAndRegisterActorHandle(const std::string &serialized, const ObjectID &outer_object_id) { std::unique_ptr actor_handle(new ActorHandle(serialized)); - return actor_manager_->RegisterActorHandle(std::move(actor_handle), outer_object_id, - CurrentCallSite(), rpc_address_); + return actor_manager_->RegisterActorHandle( + std::move(actor_handle), outer_object_id, CurrentCallSite(), rpc_address_); } -Status CoreWorker::SerializeActorHandle(const ActorID &actor_id, std::string *output, +Status CoreWorker::SerializeActorHandle(const ActorID &actor_id, + std::string *output, ObjectID *actor_handle_id) const { auto actor_handle = actor_manager_->GetActorHandle(actor_id); actor_handle->Serialize(output); @@ -1961,8 +2133,10 @@ std::pair, Status> CoreWorker::GetNamedActorH } return actor_manager_->GetNamedActorHandle( - name, ray_namespace.empty() ? job_config_->ray_namespace() : ray_namespace, - CurrentCallSite(), rpc_address_); + name, + ray_namespace.empty() ? 
job_config_->ray_namespace() : ray_namespace, + CurrentCallSite(), + rpc_address_); } std::pair>, Status> @@ -2006,7 +2180,8 @@ std::pair>, Status> CoreWorker::ListNamedActorsLocalMode() { std::vector> actors; for (auto it = local_mode_named_actor_registry_.begin(); - it != local_mode_named_actor_registry_.end(); it++) { + it != local_mode_named_actor_registry_.end(); + it++) { actors.push_back(std::make_pair(/*namespace=*/"", it->first)); } return std::make_pair(actors, Status::OK()); @@ -2045,8 +2220,8 @@ Status CoreWorker::AllocateReturnObject(const ObjectID &object_id, // Mark this object as containing other object IDs. The ref counter will // keep the inner IDs in scope until the outer one is out of scope. if (!contained_object_ids.empty() && !options_.is_local_mode) { - reference_counter_->AddNestedObjectIds(object_id, contained_object_ids, - owner_address); + reference_counter_->AddNestedObjectIds( + object_id, contained_object_ids, owner_address); } // Allocate a buffer for the return object. @@ -2058,7 +2233,10 @@ Status CoreWorker::AllocateReturnObject(const ObjectID &object_id, data_buffer = std::make_shared(data_size); *task_output_inlined_bytes += static_cast(data_size); } else { - RAY_RETURN_NOT_OK(CreateExisting(metadata, data_size, object_id, owner_address, + RAY_RETURN_NOT_OK(CreateExisting(metadata, + data_size, + object_id, + owner_address, &data_buffer, /*created_by_worker=*/true)); object_already_exists = !data_buffer; @@ -2130,8 +2308,10 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, std::unique_ptr self_actor_handle( new ActorHandle(task_spec.GetSerializedActorHandle())); // Register the handle to the current actor itself. - actor_manager_->RegisterActorHandle(std::move(self_actor_handle), ObjectID::Nil(), - CurrentCallSite(), rpc_address_, + actor_manager_->RegisterActorHandle(std::move(self_actor_handle), + ObjectID::Nil(), + CurrentCallSite(), + rpc_address_, /*is_self=*/true); } RAY_LOG(INFO) << "Creating actor: " << task_spec.ActorCreationId(); @@ -2156,11 +2336,19 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, } status = options_.task_execution_callback( - task_type, task_spec.GetName(), func, - task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs, - return_ids, task_spec.GetDebuggerBreakpoint(), return_objects, - creation_task_exception_pb_bytes, is_application_level_error, - defined_concurrency_groups, name_of_concurrency_group_to_execute); + task_type, + task_spec.GetName(), + func, + task_spec.GetRequiredResources().GetResourceUnorderedMap(), + args, + arg_refs, + return_ids, + task_spec.GetDebuggerBreakpoint(), + return_objects, + creation_task_exception_pb_bytes, + is_application_level_error, + defined_concurrency_groups, + name_of_concurrency_group_to_execute); // Get the reference counts for any IDs that we borrowed during this task, // remove the local reference for these IDs, and return the ref count info to @@ -2250,8 +2438,8 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, reference_counter_->AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); - auto status = plasma_store_provider_->Get({return_id}, 0, worker_context_, &result_map, - &got_exception); + auto status = plasma_store_provider_->Get( + {return_id}, 0, worker_context_, &result_map, &got_exception); // Remove the temporary ref. 
RemoveLocalReference(return_id); @@ -2266,7 +2454,8 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, // if the raylet fails. We expect the owner of the object to handle that // case (e.g., by detecting the raylet failure and storing an error). local_raylet_client_->PinObjectIDs( - owner_address, {return_id}, + owner_address, + {return_id}, [return_id, pinned_return_object](const Status &status, const rpc::PinObjectIDsReply &reply) { if (!status.ok()) { @@ -2300,8 +2489,10 @@ std::vector CoreWorker::ExecuteTaskLocalMode( for (size_t i = 0; i < num_returns; i++) { if (!task_spec.IsActorCreationTask()) { reference_counter_->AddOwnedObject(task_spec.ReturnId(i), - /*inner_ids=*/{}, rpc_address_, - CurrentCallSite(), -1, + /*inner_ids=*/{}, + rpc_address_, + CurrentCallSite(), + -1, /*is_reconstructable=*/false, /*add_local_ref=*/true); } @@ -2313,7 +2504,10 @@ std::vector CoreWorker::ExecuteTaskLocalMode( auto old_id = GetActorId(); SetActorId(actor_id); bool is_application_level_error; - RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs, + RAY_UNUSED(ExecuteTask(task_spec, + resource_ids, + &return_objects, + &borrowed_refs, &is_application_level_error)); SetActorId(old_id); return returned_refs; @@ -2355,8 +2549,8 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, reference_counter_->AddLocalReference(arg_id, task.CallSiteString()); // Attach the argument's owner's address. This is needed to retrieve the // value from plasma. - reference_counter_->AddBorrowedObject(arg_id, ObjectID::Nil(), - task.ArgRef(i).owner_address()); + reference_counter_->AddBorrowedObject( + arg_id, ObjectID::Nil(), task.ArgRef(i).owner_address()); borrowed_ids->push_back(arg_id); } else { // A pass-by-value argument. @@ -2400,8 +2594,8 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, RAY_RETURN_NOT_OK( memory_store_->Get(by_ref_ids, -1, worker_context_, &result_map, &got_exception)); } else { - RAY_RETURN_NOT_OK(plasma_store_provider_->Get(by_ref_ids, -1, worker_context_, - &result_map, &got_exception)); + RAY_RETURN_NOT_OK(plasma_store_provider_->Get( + by_ref_ids, -1, worker_context_, &result_map, &got_exception)); } for (const auto &it : result_map) { for (size_t idx : by_ref_indices[it.first]) { @@ -2507,15 +2701,16 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques // Send the reply once the value has become available. The value is // guaranteed to become available eventually because we own the object and // its ref count is > 0. 
-    memory_store_->GetAsync(object_id, [this, object_id, reply, send_reply_callback,
-                                        is_freed](std::shared_ptr<RayObject> obj) {
-      if (is_freed) {
-        reply->set_status(rpc::GetObjectStatusReply::FREED);
-      } else {
-        PopulateObjectStatus(object_id, obj, reply);
-      }
-      send_reply_callback(Status::OK(), nullptr, nullptr);
-    });
+    memory_store_->GetAsync(object_id,
+                            [this, object_id, reply, send_reply_callback, is_freed](
+                                std::shared_ptr<RayObject> obj) {
+                              if (is_freed) {
+                                reply->set_status(rpc::GetObjectStatusReply::FREED);
+                              } else {
+                                PopulateObjectStatus(object_id, obj, reply);
+                              }
+                              send_reply_callback(Status::OK(), nullptr, nullptr);
+                            });
   }
 
   RemoveLocalReference(object_id);
@@ -2556,7 +2751,8 @@ void CoreWorker::PopulateObjectStatus(const ObjectID &object_id,
 
 void CoreWorker::HandleWaitForActorOutOfScope(
     const rpc::WaitForActorOutOfScopeRequest &request,
-    rpc::WaitForActorOutOfScopeReply *reply, rpc::SendReplyCallback send_reply_callback) {
+    rpc::WaitForActorOutOfScopeReply *reply,
+    rpc::SendReplyCallback send_reply_callback) {
   // Currently WaitForActorOutOfScope is only used when GCS actor service is enabled.
   if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()),
                            send_reply_callback)) {
@@ -2647,11 +2843,13 @@ void CoreWorker::ProcessPubsubCommands(const Commands &commands,
                                        const NodeID &subscriber_id) {
   for (const auto &command : commands) {
     if (command.has_unsubscribe_message()) {
-      object_info_publisher_->UnregisterSubscription(command.channel_type(),
-                                                     subscriber_id, command.key_id());
+      object_info_publisher_->UnregisterSubscription(
+          command.channel_type(), subscriber_id, command.key_id());
     } else if (command.has_subscribe_message()) {
-      ProcessSubscribeMessage(command.subscribe_message(), command.channel_type(),
-                              command.key_id(), subscriber_id);
+      ProcessSubscribeMessage(command.subscribe_message(),
+                              command.channel_type(),
+                              command.key_id(),
+                              subscriber_id);
     } else {
       RAY_LOG(FATAL) << "Invalid command has received, "
                      << static_cast<int>(command.command_message_one_of_case())
@@ -2667,8 +2865,8 @@ void CoreWorker::HandlePubsubLongPolling(const rpc::PubsubLongPollingRequest &re
                                          rpc::SendReplyCallback send_reply_callback) {
   const auto subscriber_id = NodeID::FromBinary(request.subscriber_id());
   RAY_LOG(DEBUG) << "Got a long polling request from a node " << subscriber_id;
-  object_info_publisher_->ConnectToSubscriber(request, reply,
-                                              std::move(send_reply_callback));
+  object_info_publisher_->ConnectToSubscriber(
+      request, reply, std::move(send_reply_callback));
 }
 
 void CoreWorker::HandlePubsubCommandBatch(const rpc::PubsubCommandBatchRequest &request,
@@ -2704,7 +2902,8 @@ void CoreWorker::HandleUpdateObjectLocationBatch(
     }
   }
 
-  send_reply_callback(Status::OK(), /*success_callback_on_reply*/ nullptr,
+  send_reply_callback(Status::OK(),
+                      /*success_callback_on_reply*/ nullptr,
                       /*failure_callback_on_reply*/ nullptr);
 }
 
@@ -2784,15 +2983,16 @@ void CoreWorker::ProcessSubscribeForRefRemoved(
   const auto owner_address = message.reference().owner_address();
   ObjectID contained_in_id = ObjectID::FromBinary(message.contained_in_id());
 
-  reference_counter_->SetRefRemovedCallback(object_id, contained_in_id, owner_address,
-                                            ref_removed_callback);
+  reference_counter_->SetRefRemovedCallback(
+      object_id, contained_in_id, owner_address, ref_removed_callback);
 }
 
 void CoreWorker::HandleRemoteCancelTask(const rpc::RemoteCancelTaskRequest &request,
                                         rpc::RemoteCancelTaskReply *reply,
                                         rpc::SendReplyCallback send_reply_callback) {
   auto status =
CancelTask(ObjectID::FromBinary(request.remote_object_id()), - request.force_kill(), request.recursive()); + request.force_kill(), + request.recursive()); send_reply_callback(status, nullptr, nullptr); } @@ -2936,8 +3136,8 @@ void CoreWorker::HandleLocalGC(const rpc::LocalGCRequest &request, options_.gc_collect(); send_reply_callback(Status::OK(), nullptr, nullptr); } else { - send_reply_callback(Status::NotImplemented("GC callback not defined"), nullptr, - nullptr); + send_reply_callback( + Status::NotImplemented("GC callback not defined"), nullptr, nullptr); } } @@ -2953,8 +3153,8 @@ void CoreWorker::HandleSpillObjects(const rpc::SpillObjectsRequest &request, } send_reply_callback(Status::OK(), nullptr, nullptr); } else { - send_reply_callback(Status::NotImplemented("Spill objects callback not defined"), - nullptr, nullptr); + send_reply_callback( + Status::NotImplemented("Spill objects callback not defined"), nullptr, nullptr); } } @@ -2978,7 +3178,8 @@ void CoreWorker::HandleAddSpilledUrl(const rpc::AddSpilledUrlRequest &request, void CoreWorker::HandleRestoreSpilledObjects( const rpc::RestoreSpilledObjectsRequest &request, - rpc::RestoreSpilledObjectsReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::RestoreSpilledObjectsReply *reply, + rpc::SendReplyCallback send_reply_callback) { if (options_.restore_spilled_objects != nullptr) { // Get a list of object ids. std::vector object_refs_to_restore; @@ -3000,14 +3201,16 @@ void CoreWorker::HandleRestoreSpilledObjects( send_reply_callback(Status::OK(), nullptr, nullptr); } else { send_reply_callback( - Status::NotImplemented("Restore spilled objects callback not defined"), nullptr, + Status::NotImplemented("Restore spilled objects callback not defined"), + nullptr, nullptr); } } void CoreWorker::HandleDeleteSpilledObjects( const rpc::DeleteSpilledObjectsRequest &request, - rpc::DeleteSpilledObjectsReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::DeleteSpilledObjectsReply *reply, + rpc::SendReplyCallback send_reply_callback) { if (options_.delete_spilled_objects != nullptr) { std::vector spilled_objects_url; spilled_objects_url.reserve(request.spilled_objects_url_size()); @@ -3018,12 +3221,14 @@ void CoreWorker::HandleDeleteSpilledObjects( send_reply_callback(Status::OK(), nullptr, nullptr); } else { send_reply_callback( - Status::NotImplemented("Delete spilled objects callback not defined"), nullptr, + Status::NotImplemented("Delete spilled objects callback not defined"), + nullptr, nullptr); } } -void CoreWorker::HandleExit(const rpc::ExitRequest &request, rpc::ExitReply *reply, +void CoreWorker::HandleExit(const rpc::ExitRequest &request, + rpc::ExitReply *reply, rpc::SendReplyCallback send_reply_callback) { bool own_objects = reference_counter_->OwnObjects(); int64_t pins_in_flight = local_raylet_client_->GetPinsInFlight(); @@ -3056,7 +3261,11 @@ void CoreWorker::HandleAssignObjectOwner(const rpc::AssignObjectOwnerRequest &re contained_object_ids.push_back(ObjectID::FromBinary(id_binary)); } reference_counter_->AddOwnedObject( - object_id, contained_object_ids, rpc_address_, call_site, request.object_size(), + object_id, + contained_object_ids, + rpc_address_, + call_site, + request.object_size(), /*is_reconstructable=*/false, /*add_local_ref=*/false, /*pinned_at_raylet_id=*/NodeID::FromBinary(borrower_address.raylet_id())); @@ -3071,24 +3280,30 @@ void CoreWorker::YieldCurrentFiber(FiberEvent &event) { event.Wait(); } -void CoreWorker::GetAsync(const ObjectID &object_id, SetResultCallback 
success_callback, +void CoreWorker::GetAsync(const ObjectID &object_id, + SetResultCallback success_callback, void *python_future) { - auto fallback_callback = - std::bind(&CoreWorker::PlasmaCallback, this, success_callback, - std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); + auto fallback_callback = std::bind(&CoreWorker::PlasmaCallback, + this, + success_callback, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3); - memory_store_->GetAsync(object_id, [python_future, success_callback, fallback_callback, - object_id](std::shared_ptr ray_object) { - if (ray_object->IsInPlasmaError()) { - fallback_callback(ray_object, object_id, python_future); - } else { - success_callback(ray_object, object_id, python_future); - } - }); + memory_store_->GetAsync(object_id, + [python_future, success_callback, fallback_callback, object_id]( + std::shared_ptr ray_object) { + if (ray_object->IsInPlasmaError()) { + fallback_callback(ray_object, object_id, python_future); + } else { + success_callback(ray_object, object_id, python_future); + } + }); } void CoreWorker::PlasmaCallback(SetResultCallback success, - std::shared_ptr ray_object, ObjectID object_id, + std::shared_ptr ray_object, + ObjectID object_id, void *py_future) { RAY_CHECK(ray_object->IsInPlasmaError()); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index d12a41b89..8a2f5667d 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -213,7 +213,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[out] owner_address The address of the object's owner. This should /// be appended to the serialized object ID. /// \param[out] serialized_object_status The serialized object status protobuf. - void GetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner_address, + void GetOwnershipInfo(const ObjectID &object_id, + rpc::Address *owner_address, std::string *serialized_object_status); /// Add a reference to an ObjectID that was deserialized by the language @@ -245,7 +246,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] contained_object_ids The IDs serialized in this object. /// \param[out] object_id Generated ID of the object. /// \return Status. - Status Put(const RayObject &object, const std::vector &contained_object_ids, + Status Put(const RayObject &object, + const std::vector &contained_object_ids, ObjectID *object_id); /// Put an object with specified ID into object store. @@ -255,8 +257,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] object_id Object ID specified by the user. /// \param[in] pin_object Whether or not to tell the raylet to pin this object. /// \return Status. - Status Put(const RayObject &object, const std::vector &contained_object_ids, - const ObjectID &object_id, bool pin_object = false); + Status Put(const RayObject &object, + const std::vector &contained_object_ids, + const ObjectID &object_id, + bool pin_object = false); /// Create and return a buffer in the object store that can be directly written /// into. After writing to the buffer, the caller must call `SealOwned()` to @@ -281,9 +285,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// small. /// \return Status. 
Status CreateOwnedAndIncrementLocalRef( - const std::shared_ptr &metadata, const size_t data_size, - const std::vector &contained_object_ids, ObjectID *object_id, - std::shared_ptr *data, bool created_by_worker, + const std::shared_ptr &metadata, + const size_t data_size, + const std::vector &contained_object_ids, + ObjectID *object_id, + std::shared_ptr *data, + bool created_by_worker, const std::unique_ptr &owner_address = nullptr, bool inline_small_object = true); @@ -299,9 +306,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] owner_address The address of the object's owner. /// \param[out] data Buffer for the user to write the object into. /// \return Status. - Status CreateExisting(const std::shared_ptr &metadata, const size_t data_size, - const ObjectID &object_id, const rpc::Address &owner_address, - std::shared_ptr *data, bool created_by_worker); + Status CreateExisting(const std::shared_ptr &metadata, + const size_t data_size, + const ObjectID &object_id, + const rpc::Address &owner_address, + std::shared_ptr *data, + bool created_by_worker); /// Finalize placing an object into the object store. This should be called after /// a corresponding `CreateOwned()` call and then writing into the returned buffer. @@ -315,7 +325,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] The address of object's owner. If not provided, /// defaults to this worker. /// \return Status. - Status SealOwned(const ObjectID &object_id, bool pin_object, + Status SealOwned(const ObjectID &object_id, + bool pin_object, const std::unique_ptr &owner_address = nullptr); /// Finalize placing an object into the object store. This should be called after @@ -326,7 +337,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] owner_address Address of the owner of the object who will be contacted by /// the raylet if the object is pinned. If not provided, defaults to this worker. /// \return Status. - Status SealExisting(const ObjectID &object_id, bool pin_object, + Status SealExisting(const ObjectID &object_id, + bool pin_object, const std::unique_ptr &owner_address = nullptr); /// Get a list of objects from the object store. Objects that failed to be retrieved @@ -336,7 +348,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. /// \param[out] results Result list of objects data. /// \return Status. - Status Get(const std::vector &ids, const int64_t timeout_ms, + Status Get(const std::vector &ids, + const int64_t timeout_ms, std::vector> *results); /// Get objects directly from the local plasma store, without waiting for the @@ -359,7 +372,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[out] has_object Whether or not the object is present. /// \param[out] is_in_plasma Whether or not the object is in Plasma. /// \return Status. - Status Contains(const ObjectID &object_id, bool *has_object, + Status Contains(const ObjectID &object_id, + bool *has_object, bool *is_in_plasma = nullptr); /// Wait for a list of objects to appear in the object store. @@ -373,8 +387,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. /// \param[out] results A bitset that indicates each object has appeared or not. /// \return Status. 
-  Status Wait(const std::vector<ObjectID> &object_ids, const int num_objects,
-              const int64_t timeout_ms, std::vector<bool> *results, bool fetch_local);
+  Status Wait(const std::vector<ObjectID> &object_ids,
+              const int num_objects,
+              const int64_t timeout_ms,
+              std::vector<bool> *results,
+              bool fetch_local);
 
   /// Delete a list of objects from the plasma object store.
   ///
@@ -391,7 +408,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative.
   /// \param[out] results Result list of object locations.
   /// \return Status.
-  Status GetLocationFromOwner(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
+  Status GetLocationFromOwner(const std::vector<ObjectID> &object_ids,
+                              int64_t timeout_ms,
                               std::vector<std::shared_ptr<ObjectLocation>> *results);
 
   /// Trigger garbage collection on each worker in the cluster.
@@ -421,8 +439,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   /// \param[in] The error message.
   /// \param[in] The timestamp of the error.
   /// \return Status.
-  Status PushError(const JobID &job_id, const std::string &type,
-                   const std::string &error_message, double timestamp);
+  Status PushError(const JobID &job_id,
+                   const std::string &type,
+                   const std::string &error_message,
+                   double timestamp);
 
   /// Submit a normal task.
   ///
@@ -436,8 +456,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   /// should capture parent's placement group implicilty.
   /// \return ObjectRefs returned by this task.
   std::vector<rpc::ObjectReference> SubmitTask(
-      const RayFunction &function, const std::vector<std::unique_ptr<TaskArg>> &args,
-      const TaskOptions &task_options, int max_retries, bool retry_exceptions,
+      const RayFunction &function,
+      const std::vector<std::unique_ptr<TaskArg>> &args,
+      const TaskOptions &task_options,
+      int max_retries,
+      bool retry_exceptions,
       const rpc::SchedulingStrategy &scheduling_strategy,
       const std::string &debugger_breakpoint);
 
@@ -455,7 +478,8 @@ Status CreateActor(const RayFunction &function,
                      const std::vector<std::unique_ptr<TaskArg>> &args,
                      const ActorCreationOptions &actor_creation_options,
-                     const std::string &extension_data, ActorID *actor_id);
+                     const std::string &extension_data,
+                     ActorID *actor_id);
 
   /// Create a placement group.
   ///
@@ -496,8 +520,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   /// \param[in] task_options Options for this task.
   /// \return ObjectRefs returned by this task.
   std::optional<std::vector<rpc::ObjectReference>> SubmitActorTask(
-      const ActorID &actor_id, const RayFunction &function,
-      const std::vector<std::unique_ptr<TaskArg>> &args, const TaskOptions &task_options);
+      const ActorID &actor_id,
+      const RayFunction &function,
+      const std::vector<std::unique_ptr<TaskArg>> &args,
+      const TaskOptions &task_options);
 
   /// Tell an actor to exit immediately, without completing outstanding work.
   ///
@@ -546,7 +572,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   /// serialized actor handle in the language frontend is stored inside an
   /// object, then this must be recorded in the worker's ReferenceCounter.
   /// \return Status::Invalid if we don't have the specified handle.
-  Status SerializeActorHandle(const ActorID &actor_id, std::string *output,
+  Status SerializeActorHandle(const ActorID &actor_id,
+                              std::string *output,
                               ObjectID *actor_handle_id) const;
 
   ///
@@ -583,7 +610,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   /// the current object is inlined, the task_output_inlined_bytes will be updated.
   /// \param[out] return_object RayObject containing buffers to write results into.
   /// \return Status.
- Status AllocateReturnObject(const ObjectID &object_id, const size_t &data_size, + Status AllocateReturnObject(const ObjectID &object_id, + const size_t &data_size, const std::shared_ptr &metadata, const std::vector &contained_object_id, int64_t *task_output_inlined_bytes, @@ -646,7 +674,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// /// Implements gRPC server handler. - void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + void HandlePushTask(const rpc::PushTaskRequest &request, + rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Implements gRPC server handler. @@ -687,7 +716,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::SendReplyCallback send_reply_callback) override; /// Implements gRPC server handler. - void HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, + void HandleKillActor(const rpc::KillActorRequest &request, + rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Implements gRPC server handler. @@ -711,7 +741,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::SendReplyCallback send_reply_callback) override; /// Trigger local GC on this worker. - void HandleLocalGC(const rpc::LocalGCRequest &request, rpc::LocalGCReply *reply, + void HandleLocalGC(const rpc::LocalGCRequest &request, + rpc::LocalGCReply *reply, rpc::SendReplyCallback send_reply_callback) override; // Spill objects to external storage. @@ -736,7 +767,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { // Make the this worker exit. // This request fails if the core worker owns any object. - void HandleExit(const rpc::ExitRequest &request, rpc::ExitReply *reply, + void HandleExit(const rpc::ExitRequest &request, + rpc::ExitReply *reply, rpc::SendReplyCallback send_reply_callback) override; // Set local worker as the owner of object. @@ -763,7 +795,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] success_callback The callback to use the result object. /// \param[in] python_future the void* object to be passed to SetResultCallback /// \return void - void GetAsync(const ObjectID &object_id, SetResultCallback success_callback, + void GetAsync(const ObjectID &object_id, + SetResultCallback success_callback, void *python_future); // Get serialized job configuration. 
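Every hunk in this patch is the same mechanical rewrite, produced by clang-format rather than by hand. As a minimal sketch of the rule, using the Wait and Delete declarations that appear in the hunks above: when a parameter list fits on one line it is left alone; once it cannot fit, every parameter gets its own line instead of being bin-packed to fill each 80-column line.

    // Old style: parameters bin-packed, new parameters appended to partly full lines.
    Status Wait(const std::vector<ObjectID> &object_ids, const int num_objects,
                const int64_t timeout_ms, std::vector<bool> *results, bool fetch_local);

    // New style: one parameter per line, aligned with the first parameter.
    Status Wait(const std::vector<ObjectID> &object_ids,
                const int num_objects,
                const int64_t timeout_ms,
                std::vector<bool> *results,
                bool fetch_local);

    // Short signatures still fit on a single line and are unchanged.
    Status Delete(const std::vector<ObjectID> &object_ids, bool local_only);

Because adding or removing one parameter now touches exactly one line, a merge conflict stays local to that line instead of reflowing the whole list.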
@@ -795,13 +828,21 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { const std::string &serialized_runtime_env_info); void BuildCommonTaskSpec( - TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, - const std::string &name, const TaskID ¤t_task_id, uint64_t task_index, - const TaskID &caller_id, const rpc::Address &address, const RayFunction &function, - const std::vector> &args, uint64_t num_returns, + TaskSpecBuilder &builder, + const JobID &job_id, + const TaskID &task_id, + const std::string &name, + const TaskID ¤t_task_id, + uint64_t task_index, + const TaskID &caller_id, + const rpc::Address &address, + const RayFunction &function, + const std::vector> &args, + uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const std::string &debugger_breakpoint, int64_t depth, + const std::string &debugger_breakpoint, + int64_t depth, const std::string &serialized_runtime_env_info, const std::string &concurrency_group_name = ""); void SetCurrentTaskId(const TaskID &task_id, uint64_t attempt_number); @@ -827,7 +868,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void InternalHeartbeat(); /// Helper method to fill in object status reply given an object. - void PopulateObjectStatus(const ObjectID &object_id, std::shared_ptr obj, + void PopulateObjectStatus(const ObjectID &object_id, + std::shared_ptr obj, rpc::GetObjectStatusReply *reply); /// @@ -874,7 +916,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { bool *is_application_level_error); /// Put an object in the local plasma store. - Status PutInLocalPlasmaStore(const RayObject &object, const ObjectID &object_id, + Status PutInLocalPlasmaStore(const RayObject &object, + const ObjectID &object_id, bool pin_object); /// Execute a local mode task (runs normal ExecuteTask) @@ -947,7 +990,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Process the subscribe message received from the subscriber. void ProcessSubscribeMessage(const rpc::SubMessage &sub_message, - rpc::ChannelType channel_type, const std::string &key_id, + rpc::ChannelType channel_type, + const std::string &key_id, const NodeID &subscriber_id); /// A single endpoint to process different types of pubsub commands. @@ -984,7 +1028,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Request the spillage of an object that we own from the primary that hosts /// the primary copy to spill. - void SpillOwnedObject(const ObjectID &object_id, const std::shared_ptr &obj, + void SpillOwnedObject(const ObjectID &object_id, + const std::shared_ptr &obj, std::function callback); const CoreWorkerOptions options_; @@ -1170,8 +1215,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { async_plasma_callbacks_ GUARDED_BY(plasma_mutex_); // Fallback for when GetAsync cannot directly get the requested object. - void PlasmaCallback(SetResultCallback success, std::shared_ptr ray_object, - ObjectID object_id, void *py_future); + void PlasmaCallback(SetResultCallback success, + std::shared_ptr ray_object, + ObjectID object_id, + void *py_future); /// we are shutting down and not running further tasks. /// when exiting_ is set to true HandlePushTask becomes no-op. 
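The two-line .clang-format change that drives all of these hunks is not reproduced in this section. For reference, the standard clang-format options for this style are the two bin-packing switches; that these exact two lines are what the patch adds is an assumption, not quoted from this section:

    # .clang-format (sketch, assuming the usual options for one-per-line wrapping)
    BinPackArguments: false    # call sites: all args on one line, or one arg per line
    BinPackParameters: false   # declarations/definitions: same rule for parameters

With bin-packing disabled, clang-format keeps a list on a single line when it fits and otherwise breaks before every element, which is exactly the layout seen throughout the hunks above.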
diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 56021c0bb..784024b85 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -33,11 +33,14 @@ struct CoreWorkerOptions { // Callback that must be implemented and provided by the language-specific worker // frontend to execute tasks and return their results. using TaskExecutionCallback = std::function &required_resources, const std::vector> &args, const std::vector &arg_refs, - const std::vector &return_ids, const std::string &debugger_breakpoint, + const std::vector &return_ids, + const std::string &debugger_breakpoint, std::vector> *results, std::shared_ptr &creation_task_exception_pb_bytes, bool *is_application_level_error, diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index 3385374a3..abc3dde6e 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -144,7 +144,8 @@ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options) // Initialize event framework. if (RayConfig::instance().event_log_reporter_enabled() && !options_.log_dir.empty()) { RayEventInit(ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER, - absl::flat_hash_map(), options_.log_dir, + absl::flat_hash_map(), + options_.log_dir, RayConfig::instance().event_level()); } } @@ -188,7 +189,10 @@ void CoreWorkerProcessImpl::InitializeSystemConfig() { options_.raylet_ip_address, options_.node_manager_port, client_call_manager); raylet::RayletClient raylet_client(grpc_client); - std::function get_once = [this, &get_once, &raylet_client, &promise, + std::function get_once = [this, + &get_once, + &raylet_client, + &promise, &io_service](int64_t num_attempts) { raylet_client.GetSystemConfig( [this, num_attempts, &get_once, &promise, &io_service]( diff --git a/src/ray/core_worker/future_resolver.cc b/src/ray/core_worker/future_resolver.cc index a46d38d96..b3f68ee00 100644 --- a/src/ray/core_worker/future_resolver.cc +++ b/src/ray/core_worker/future_resolver.cc @@ -30,8 +30,9 @@ void FutureResolver::ResolveFutureAsync(const ObjectID &object_id, request.set_object_id(object_id.Binary()); request.set_owner_worker_id(owner_address.worker_id()); conn->GetObjectStatus( - request, [this, object_id, owner_address](const Status &status, - const rpc::GetObjectStatusReply &reply) { + request, + [this, object_id, owner_address](const Status &status, + const rpc::GetObjectStatusReply &reply) { ProcessResolvedObject(object_id, owner_address, status, reply); }); } @@ -97,7 +98,8 @@ void FutureResolver::ProcessResolvedObject(const ObjectID &object_id, VectorFromProtobuf(reply.object().nested_inlined_refs()); for (const auto &inlined_ref : inlined_refs) { reference_counter_->AddBorrowedObject(ObjectID::FromBinary(inlined_ref.object_id()), - object_id, inlined_ref.owner_address()); + object_id, + inlined_ref.owner_address()); } RAY_UNUSED(in_memory_store_->Put( RayObject(data_buffer, metadata_buffer, inlined_refs), object_id)); diff --git a/src/ray/core_worker/future_resolver.h b/src/ray/core_worker/future_resolver.h index 267451c00..bbf305a0b 100644 --- a/src/ray/core_worker/future_resolver.h +++ b/src/ray/core_worker/future_resolver.h @@ -60,7 +60,8 @@ class FutureResolver { /// \param[in] object_id The ID of the future to resolve. /// \param[in] status Any error code from the owner obtaining the object status. /// \param[in] object_status The object status. 
- void ProcessResolvedObject(const ObjectID &object_id, const rpc::Address &owner_address, + void ProcessResolvedObject(const ObjectID &object_id, + const rpc::Address &owner_address, const Status &status, const rpc::GetObjectStatusReply &object_status); diff --git a/src/ray/core_worker/gcs_server_address_updater.cc b/src/ray/core_worker/gcs_server_address_updater.cc index f89f6fd28..e0c90878a 100644 --- a/src/ray/core_worker/gcs_server_address_updater.cc +++ b/src/ray/core_worker/gcs_server_address_updater.cc @@ -18,11 +18,12 @@ namespace ray { namespace core { GcsServerAddressUpdater::GcsServerAddressUpdater( - const std::string raylet_ip_address, const int port, + const std::string raylet_ip_address, + const int port, std::function update_func) : client_call_manager_(updater_io_service_), - raylet_client_(rpc::NodeManagerWorkerClient::make(raylet_ip_address, port, - client_call_manager_)), + raylet_client_(rpc::NodeManagerWorkerClient::make( + raylet_ip_address, port, client_call_manager_)), update_func_(update_func), updater_runner_(updater_io_service_), updater_thread_([this] { diff --git a/src/ray/core_worker/gcs_server_address_updater.h b/src/ray/core_worker/gcs_server_address_updater.h index b1f9e84c4..5dbae641b 100644 --- a/src/ray/core_worker/gcs_server_address_updater.h +++ b/src/ray/core_worker/gcs_server_address_updater.h @@ -28,7 +28,8 @@ class GcsServerAddressUpdater { /// \param raylet_ip_address Raylet ip address. /// \param port Port to connect raylet. /// \param address to store gcs server address. - GcsServerAddressUpdater(const std::string raylet_ip_address, const int port, + GcsServerAddressUpdater(const std::string raylet_ip_address, + const int port, std::function update_func); ~GcsServerAddressUpdater(); diff --git a/src/ray/core_worker/lease_policy.h b/src/ray/core_worker/lease_policy.h index fbf963c99..6634f78a8 100644 --- a/src/ray/core_worker/lease_policy.h +++ b/src/ray/core_worker/lease_policy.h @@ -56,7 +56,8 @@ class LocalityAwareLeasePolicy : public LeasePolicyInterface { public: LocalityAwareLeasePolicy( std::shared_ptr locality_data_provider, - NodeAddrFactory node_addr_factory, const rpc::Address fallback_rpc_address) + NodeAddrFactory node_addr_factory, + const rpc::Address fallback_rpc_address) : locality_data_provider_(locality_data_provider), node_addr_factory_(node_addr_factory), fallback_rpc_address_(fallback_rpc_address) {} diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 3623911fb..583992fdf 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -48,7 +48,8 @@ inline gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, jobject gcs_client_ } } -jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results, +jobject ToJavaArgs(JNIEnv *env, + jbooleanArray java_check_results, const std::vector> &args) { if (java_check_results == nullptr) { // If `java_check_results` is null, it means that `checkByteBufferArguments` @@ -58,7 +59,8 @@ jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results, jboolean *check_results = env->GetBooleanArrayElements(java_check_results, nullptr); size_t i = 0; jobject args_array_list = NativeVectorToJavaList>( - env, args, + env, + args, [check_results, &i](JNIEnv *env, const std::shared_ptr &native_object) { if (*(check_results + (i++))) { @@ -94,17 +96,31 @@ JNIEnv *GetJNIEnv() { extern "C" { #endif -JNIEXPORT void 
JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( - JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort, - jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId, - jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir, - jbyteArray jobConfig, jint startupToken, jint runtimeEnvHash) { +JNIEXPORT void JNICALL +Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, + jclass, + jint workerMode, + jstring nodeIpAddress, + jint nodeManagerPort, + jstring driverName, + jstring storeSocket, + jstring rayletSocket, + jbyteArray jobId, + jobject gcsClientOptions, + jint numWorkersPerProcess, + jstring logDir, + jbyteArray jobConfig, + jint startupToken, + jint runtimeEnvHash) { auto task_execution_callback = - [](TaskType task_type, const std::string task_name, const RayFunction &ray_function, + [](TaskType task_type, + const std::string task_name, + const RayFunction &ray_function, const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_refs, - const std::vector &return_ids, const std::string &debugger_breakpoint, + const std::vector &return_ids, + const std::string &debugger_breakpoint, std::vector> *results, std::shared_ptr &creation_task_exception_pb, bool *is_application_level_error, @@ -139,17 +155,18 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( // convert args // TODO (kfstorm): Avoid copying binary data from Java to C++ - jbooleanArray java_check_results = - static_cast(env->CallObjectMethod( - java_task_executor, java_task_executor_parse_function_arguments, - ray_function_array_list)); + jbooleanArray java_check_results = static_cast( + env->CallObjectMethod(java_task_executor, + java_task_executor_parse_function_arguments, + ray_function_array_list)); RAY_CHECK_JAVA_EXCEPTION(env); jobject args_array_list = ToJavaArgs(env, java_check_results, args); // invoke Java method - jobject java_return_objects = - env->CallObjectMethod(java_task_executor, java_task_executor_execute, - ray_function_array_list, args_array_list); + jobject java_return_objects = env->CallObjectMethod(java_task_executor, + java_task_executor_execute, + ray_function_array_list, + args_array_list); // Check whether the exception is `IntentionalSystemExit`. jthrowable throwable = env->ExceptionOccurred(); if (throwable) { @@ -173,7 +190,9 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( if (!return_ids.empty()) { std::vector> return_objects; JavaListToNativeVector>( - env, java_return_objects, &return_objects, + env, + java_return_objects, + &return_objects, [](JNIEnv *env, jobject java_native_ray_object) { return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object); }); @@ -190,14 +209,19 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( auto result_ptr = &(*results)[0]; RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().AllocateReturnObject( - result_id, data_size, metadata, contained_object_ids, - &task_output_inlined_bytes, result_ptr)); + result_id, + data_size, + metadata, + contained_object_ids, + &task_output_inlined_bytes, + result_ptr)); // A nullptr is returned if the object already exists. 
auto result = *result_ptr; if (result != nullptr) { if (result->HasData()) { - memcpy(result->GetData()->Data(), return_objects[i]->GetData()->Data(), + memcpy(result->GetData()->Data(), + return_objects[i]->GetData()->Data(), data_size); } } @@ -238,7 +262,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( auto worker_id_bytes = IdToJavaByteArray(env, worker_id); if (java_task_executor) { env->CallVoidMethod(java_task_executor, - java_native_task_executor_on_worker_shutdown, worker_id_bytes); + java_native_task_executor_on_worker_shutdown, + worker_id_bytes); RAY_CHECK_JAVA_EXCEPTION(env); } }; @@ -293,7 +318,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEn } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *env, jclass, +Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *env, + jclass, jstring actor_name, jstring ray_namespace) { const char *native_actor_name = env->GetStringUTFChars(actor_name, JNI_FALSE); @@ -316,7 +342,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeKillActor( JNIEnv *env, jclass, jbyteArray actorId, jboolean noRestart) { auto status = CoreWorkerProcess::GetCoreWorker().KillActor( JavaByteArrayToId(env, actorId), - /*force_kill=*/true, noRestart); + /*force_kill=*/true, + noRestart); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } @@ -335,14 +362,15 @@ Java_io_ray_runtime_RayNativeRuntime_nativeGetResourceIds(JNIEnv *env, jclass) { [](JNIEnv *env, const std::vector> &value) -> jobject { auto elem_converter = [](JNIEnv *env, const std::pair &elem) -> jobject { - jobject java_item = - env->NewObject(java_resource_value_class, java_resource_value_init, - (jlong)elem.first, (jdouble)elem.second); + jobject java_item = env->NewObject(java_resource_value_class, + java_resource_value_init, + (jlong)elem.first, + (jdouble)elem.second); RAY_CHECK_JAVA_EXCEPTION(env); return java_item; }; - return NativeVectorToJavaList>(env, value, - std::move(elem_converter)); + return NativeVectorToJavaList>( + env, value, std::move(elem_converter)); }; ResourceMappingType resource_mapping = CoreWorkerProcess::GetCoreWorker().GetResourceIDs(); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index 0ffa3b91b..7d8e6f85b 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -27,9 +27,21 @@ extern "C" { * Signature: * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V */ -JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( - JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject, - jint, jstring, jbyteArray, jint, jint); +JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *, + jclass, + jint, + jstring, + jint, + jstring, + jstring, + jstring, + jbyteArray, + jobject, + jint, + jstring, + jbyteArray, + jint, + jint); /* * Class: io_ray_runtime_RayNativeRuntime @@ -63,8 +75,10 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeKillActor(JNIE * Signature: (Ljava/lang/String;Ljava/lang/String;)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *, jclass, - jstring, jstring); 
+Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *, + jclass, + jstring, + jstring); /* * Class: io_ray_runtime_RayNativeRuntime diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc index 97886a596..ae02431f5 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc @@ -52,14 +52,17 @@ JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeS ObjectID actor_handle_id; Status status = CoreWorkerProcess::GetCoreWorker().SerializeActorHandle( actor_id, &output, &actor_handle_id); - env->SetByteArrayRegion(actorHandleId, 0, ObjectID::kLength, + env->SetByteArrayRegion(actorHandleId, + 0, + ObjectID::kLength, reinterpret_cast(actor_handle_id.Data())); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeStringToJavaByteArray(env, output); } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *env, jclass o, +Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *env, + jclass o, jbyteArray data) { auto buffer = JavaByteArrayToNativeBuffer(env, data); RAY_CHECK(buffer->Size() > 0); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h index b549096c9..ed3f14227 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h @@ -52,7 +52,8 @@ JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeS * Signature: ([B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *, jclass, +Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *, + jclass, jbyteArray); /* diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc index 68f00e12e..9dc333e31 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc @@ -38,8 +38,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeCreateGlobalStateAccessor( std::vector results; boost::split(results, bootstrap_address, boost::is_any_of(":")); RAY_CHECK(results.size() == 2); - ray::gcs::GcsClientOptions client_options(results[0], std::stoi(results[1]), - redis_password); + ray::gcs::GcsClientOptions client_options( + results[0], std::stoi(results[1]), redis_password); gcs_accessor = new gcs::GlobalStateAccessor(client_options); } return reinterpret_cast(gcs_accessor); @@ -69,7 +69,8 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetA } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNextJobID(JNIEnv *env, jobject o, +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNextJobID(JNIEnv *env, + jobject o, jlong gcs_accessor_ptr) { auto *gcs_accessor = reinterpret_cast(gcs_accessor_ptr); const auto &job_id = gcs_accessor->GetNextJobID(); @@ -77,7 +78,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNextJobID(JNIEnv *env, jobj } JNIEXPORT jobject JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *env, jobject o, 
+Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *env, + jobject o, jlong gcs_accessor_ptr) { auto *gcs_accessor = reinterpret_cast(gcs_accessor_ptr); auto node_info_list = gcs_accessor->GetAllNodeInfo(); @@ -108,7 +110,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllActorInfo( } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *env, jobject o, +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *env, + jobject o, jlong gcs_accessor_ptr, jbyteArray actorId) { const auto actor_id = JavaByteArrayToId(env, actorId); @@ -158,9 +161,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllPlacementGroupInfo( } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetInternalKV(JNIEnv *env, jobject o, - jlong gcs_accessor_ptr, - jstring n, jstring k) { +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetInternalKV( + JNIEnv *env, jobject o, jlong gcs_accessor_ptr, jstring n, jstring k) { std::string key = JavaStringToNativeString(env, k); std::string ns = JavaStringToNativeString(env, n); auto *gcs_accessor = reinterpret_cast(gcs_accessor_ptr); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h index 6251d8d07..94ef5df14 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h @@ -72,7 +72,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNextJobID(JNIEnv *, jobject * Signature: (J)Ljava/util/List; */ JNIEXPORT jobject JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *, jobject, +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *, + jobject, jlong); /* @@ -81,8 +82,10 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *, jobje * Signature: (J[B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNodeResourceInfo(JNIEnv *, jobject, - jlong, jbyteArray); +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNodeResourceInfo(JNIEnv *, + jobject, + jlong, + jbyteArray); /* * Class: io_ray_runtime_gcs_GlobalStateAccessor @@ -90,7 +93,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetNodeResourceInfo(JNIEnv *, * Signature: (J)Ljava/util/List; */ JNIEXPORT jobject JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllActorInfo(JNIEnv *, jobject, +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllActorInfo(JNIEnv *, + jobject, jlong); /* @@ -99,7 +103,9 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllActorInfo(JNIEnv *, jobj * Signature: (J[B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *, jobject, jlong, +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *, + jobject, + jlong, jbyteArray); /* @@ -108,7 +114,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *, jobject * Signature: (J[B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetPlacementGroupInfo(JNIEnv *, jobject, +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetPlacementGroupInfo(JNIEnv *, + jobject, jlong, jbyteArray); @@ -137,8 +144,8 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllPlacementGroupInfo(JNIEn * Signature: (JLjava/lang/String;Ljava/lang/String;)[B */ JNIEXPORT jbyteArray 
JNICALL -Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetInternalKV(JNIEnv *, jobject, jlong, - jstring, jstring); +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetInternalKV( + JNIEnv *, jobject, jlong, jstring, jstring); /* * Class: io_ray_runtime_gcs_GlobalStateAccessor diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.cc b/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.cc index 2803a3b9a..fcfd0a7a7 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.cc @@ -34,10 +34,15 @@ using TagsType = std::vector>; /// \param[out] description metric description in native string. /// \param[out] unit metric measurement unit in native string. /// \param[out] tag_keys metric tag key vector unit in native vector. -inline void MetricTransform(JNIEnv *env, jstring j_name, jstring j_description, - jstring j_unit, jobject tag_key_list, - std::string *metric_name, std::string *description, - std::string *unit, std::vector &tag_keys) { +inline void MetricTransform(JNIEnv *env, + jstring j_name, + jstring j_description, + jstring j_unit, + jobject tag_key_list, + std::string *metric_name, + std::string *description, + std::string *unit, + std::vector &tag_keys) { *metric_name = JavaStringToNativeString(env, static_cast(j_name)); *description = JavaStringToNativeString(env, static_cast(j_description)); *unit = JavaStringToNativeString(env, static_cast(j_unit)); @@ -46,7 +51,8 @@ inline void MetricTransform(JNIEnv *env, jstring j_name, jstring j_description, // We just call TagKeyType::Register to get tag object since opencensus tags // registry is thread-safe and registry can return a new tag or registered // item when it already exists. 
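The std::transform call in the next hunk appends one registered tag key per input string through std::back_inserter. A self-contained sketch of that pattern, with a stand-in Register so it builds without the opencensus dependency:

    #include <algorithm>
    #include <iterator>
    #include <string>
    #include <vector>

    struct TagKey {
      std::string name;
      static TagKey Register(const std::string &n) { return TagKey{n}; }
    };

    int main() {
      const std::vector<std::string> names = {"host", "job"};
      std::vector<TagKey> keys;
      // back_inserter grows `keys` as transform emits one converted
      // element per input element.
      std::transform(names.begin(), names.end(), std::back_inserter(keys),
                     [](const std::string &n) { return TagKey::Register(n); });
      return keys.size() == names.size() ? 0 : 1;
    }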
- std::transform(tag_key_str_list.begin(), tag_key_str_list.end(), + std::transform(tag_key_str_list.begin(), + tag_key_str_list.end(), std::back_inserter(tag_keys), [](std::string &tag_key) { return TagKeyType::Register(tag_key); }); } @@ -61,54 +67,99 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_metric_NativeMetric_registerTagkeyNat RAY_IGNORE_EXPR(TagKeyType::Register(tag_key_name)); } -JNIEXPORT jlong JNICALL Java_io_ray_runtime_metric_NativeMetric_registerGaugeNative( - JNIEnv *env, jclass obj, jstring j_name, jstring j_description, jstring j_unit, - jobject tag_key_list) { +JNIEXPORT jlong JNICALL +Java_io_ray_runtime_metric_NativeMetric_registerGaugeNative(JNIEnv *env, + jclass obj, + jstring j_name, + jstring j_description, + jstring j_unit, + jobject tag_key_list) { std::string metric_name; std::string description; std::string unit; std::vector<TagKeyType> tag_keys; - MetricTransform(env, j_name, j_description, j_unit, tag_key_list, &metric_name, - &description, &unit, tag_keys); + MetricTransform(env, + j_name, + j_description, + j_unit, + tag_key_list, + &metric_name, + &description, + &unit, + tag_keys); auto *gauge = new stats::Gauge(metric_name, description, unit, tag_keys); return reinterpret_cast<jlong>(gauge); } -JNIEXPORT jlong JNICALL Java_io_ray_runtime_metric_NativeMetric_registerCountNative( - JNIEnv *env, jclass obj, jstring j_name, jstring j_description, jstring j_unit, - jobject tag_key_list) { +JNIEXPORT jlong JNICALL +Java_io_ray_runtime_metric_NativeMetric_registerCountNative(JNIEnv *env, + jclass obj, + jstring j_name, + jstring j_description, + jstring j_unit, + jobject tag_key_list) { std::string metric_name; std::string description; std::string unit; std::vector<TagKeyType> tag_keys; - MetricTransform(env, j_name, j_description, j_unit, tag_key_list, &metric_name, - &description, &unit, tag_keys); + MetricTransform(env, + j_name, + j_description, + j_unit, + tag_key_list, + &metric_name, + &description, + &unit, + tag_keys); auto *count = new stats::Count(metric_name, description, unit, tag_keys); return reinterpret_cast<jlong>(count); } -JNIEXPORT jlong JNICALL Java_io_ray_runtime_metric_NativeMetric_registerSumNative( - JNIEnv *env, jclass obj, jstring j_name, jstring j_description, jstring j_unit, - jobject tag_key_list) { +JNIEXPORT jlong JNICALL +Java_io_ray_runtime_metric_NativeMetric_registerSumNative(JNIEnv *env, + jclass obj, + jstring j_name, + jstring j_description, + jstring j_unit, + jobject tag_key_list) { std::string metric_name; std::string description; std::string unit; std::vector<TagKeyType> tag_keys; - MetricTransform(env, j_name, j_description, j_unit, tag_key_list, &metric_name, - &description, &unit, tag_keys); + MetricTransform(env, + j_name, + j_description, + j_unit, + tag_key_list, + &metric_name, + &description, + &unit, + tag_keys); auto *sum = new stats::Sum(metric_name, description, unit, tag_keys); return reinterpret_cast<jlong>(sum); } -JNIEXPORT jlong JNICALL Java_io_ray_runtime_metric_NativeMetric_registerHistogramNative( - JNIEnv *env, jclass obj, jstring j_name, jstring j_description, jstring j_unit, - jdoubleArray j_boundaries, jobject tag_key_list) { +JNIEXPORT jlong JNICALL +Java_io_ray_runtime_metric_NativeMetric_registerHistogramNative(JNIEnv *env, + jclass obj, + jstring j_name, + jstring j_description, + jstring j_unit, + jdoubleArray j_boundaries, + jobject tag_key_list) { std::string metric_name; std::string description; std::string unit; std::vector<TagKeyType> tag_keys; - MetricTransform(env, j_name, j_description, j_unit, tag_key_list, &metric_name, - &description, &unit, tag_keys); + MetricTransform(env, + j_name, + j_description, + j_unit, + tag_key_list, + &metric_name, + &description, + &unit, + tag_keys); std::vector<double> boundaries; JavaDoubleArrayToNativeDoubleVector(env, j_boundaries, &boundaries); @@ -124,9 +175,13 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_metric_NativeMetric_unregisterMetricN delete metric; } -JNIEXPORT void JNICALL Java_io_ray_runtime_metric_NativeMetric_recordNative( - JNIEnv *env, jclass obj, jlong metric_native_pointer, jdouble value, - jobject tag_key_list, jobject tag_value_list) { +JNIEXPORT void JNICALL +Java_io_ray_runtime_metric_NativeMetric_recordNative(JNIEnv *env, + jclass obj, + jlong metric_native_pointer, + jdouble value, + jobject tag_key_list, + jobject tag_value_list) { stats::Metric *metric = reinterpret_cast<stats::Metric *>(metric_native_pointer); std::vector<std::string> tag_key_str_list; std::vector<std::string> tag_value_str_list; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index 887966ee7..26c8d9029 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -22,8 +22,11 @@ #include "ray/core_worker/core_worker.h" #include "ray/gcs/gcs_client/global_state_accessor.h" -Status PutSerializedObject(JNIEnv *env, jobject obj, ObjectID object_id, - ObjectID *out_object_id, bool pin_object = true, +Status PutSerializedObject(JNIEnv *env, + jobject obj, + ObjectID object_id, + ObjectID *out_object_id, + bool pin_object = true, const std::unique_ptr<rpc::Address> &owner_address = nullptr) { auto native_ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); RAY_CHECK(native_ray_object != nullptr); @@ -39,13 +42,20 @@ Status PutSerializedObject(JNIEnv *env, jobject obj, ObjectID object_id, nested_ids.push_back(ObjectID::FromBinary(ref.object_id())); } status = CoreWorkerProcess::GetCoreWorker().CreateOwnedAndIncrementLocalRef( - native_ray_object->GetMetadata(), data_size, nested_ids, out_object_id, &data, + native_ray_object->GetMetadata(), + data_size, + nested_ids, + out_object_id, + &data, /*created_by_worker=*/true, /*owner_address=*/owner_address); } else { status = CoreWorkerProcess::GetCoreWorker().CreateExisting( - native_ray_object->GetMetadata(), data_size, object_id, - CoreWorkerProcess::GetCoreWorker().GetRpcAddress(), &data, + native_ray_object->GetMetadata(), + data_size, + object_id, + CoreWorkerProcess::GetCoreWorker().GetRpcAddress(), + &data, /*created_by_worker=*/true); *out_object_id = object_id; } @@ -84,8 +94,11 @@ Java_io_ray_runtime_object_NativeObjectStore_nativePut__Lio_ray_runtime_object_N owner_address->ParseFromString( JavaByteArrayToNativeString(env, serialized_owner_actor_address_bytes)); } - auto status = PutSerializedObject(env, obj, /*object_id=*/ObjectID::Nil(), - /*out_object_id=*/&object_id, /*pin_object=*/true, + auto status = PutSerializedObject(env, + obj, + /*object_id=*/ObjectID::Nil(), + /*out_object_id=*/&object_id, + /*pin_object=*/true, /*owner_address=*/owner_address); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return IdToJavaByteArray(env, object_id); } @@ -96,9 +109,11 @@ Java_io_ray_runtime_object_NativeObjectStore_nativePut___3BLio_ray_runtime_objec JNIEnv *env, jclass, jbyteArray objectId, jobject obj) { auto object_id = JavaByteArrayToId<ObjectID>(env, objectId); ObjectID dummy_object_id; auto status = - PutSerializedObject(env, obj, object_id, - /*out_object_id=*/&dummy_object_id,
/*pin_object=*/true); + auto status = PutSerializedObject(env, + obj, + object_id, + /*out_object_id=*/&dummy_object_id, + /*pin_object=*/true); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } @@ -116,9 +131,13 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet env, results, NativeRayObjectToJavaNativeRayObject); } -JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs, - jboolean fetch_local) { +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeWait(JNIEnv *env, + jclass, + jobject objectIds, + jint numObjects, + jlong timeoutMs, + jboolean fetch_local) { std::vector object_ids; JavaListToNativeVector( env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { @@ -176,7 +195,8 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv jclass) { auto reference_counts = CoreWorkerProcess::GetCoreWorker().GetAllReferenceCounts(); return NativeMapToJavaMap>( - env, reference_counts, + env, + reference_counts, [](JNIEnv *env, const ObjectID &key) { return IdToJavaByteArray(env, key); }, @@ -191,7 +211,8 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, jclass, +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, + jclass, jbyteArray objectId) { auto object_id = JavaByteArrayToId(env, objectId); const auto &rpc_address = CoreWorkerProcess::GetCoreWorker().GetOwnerAddress(object_id); @@ -199,14 +220,15 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *env, jclass, +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *env, + jclass, jbyteArray objectId) { auto object_id = JavaByteArrayToId(env, objectId); rpc::Address address; // TODO(ekl) send serialized object status to Java land. 
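A pattern worth calling out here: nearly every conversion in these JNI files funnels through a generic element-converter helper like the JavaListToNativeVector call in nativeWait above. A dependency-free sketch of that shape; FakeEnv and FakeJObject are stand-ins so it runs without a JVM, and the real helpers live in jni_utils.h:

    #include <functional>
    #include <string>
    #include <vector>

    using FakeEnv = int;
    using FakeJObject = std::string;

    // Convert each element of a "Java list" into a native value via a
    // caller-supplied callback, mirroring the jni_utils.h helper shape.
    template <typename NativeT>
    void ListToNativeVector(
        FakeEnv *env,
        const std::vector<FakeJObject> &java_list,
        std::vector<NativeT> *out,
        std::function<NativeT(FakeEnv *, const FakeJObject &)> convert) {
      out->reserve(java_list.size());
      for (const auto &elem : java_list) {
        out->push_back(convert(env, elem));
      }
    }

    int main() {
      FakeEnv env = 0;
      std::vector<FakeJObject> ids = {"a", "b"};
      std::vector<std::string> native;
      ListToNativeVector<std::string>(
          &env, ids, &native,
          [](FakeEnv *, const FakeJObject &s) { return s + "!"; });
      return native.size() == 2 ? 0 : 1;
    }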
std::string serialized_object_status; - CoreWorkerProcess::GetCoreWorker().GetOwnershipInfo(object_id, &address, - &serialized_object_status); + CoreWorkerProcess::GetCoreWorker().GetOwnershipInfo( + object_id, &address, &serialized_object_status); auto address_str = address.SerializeAsString(); auto arr = NativeStringToJavaByteArray(env, address_str); return arr; @@ -214,7 +236,10 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *env, JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeRegisterOwnershipInfoAndResolveFuture( - JNIEnv *env, jclass, jbyteArray objectId, jbyteArray outerObjectId, + JNIEnv *env, + jclass, + jbyteArray objectId, + jbyteArray outerObjectId, jbyteArray ownerAddress) { auto object_id = JavaByteArrayToId(env, objectId); auto outer_objectId = ObjectID::Nil(); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 9358f4473..de8bf810e 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -71,7 +71,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete * Signature: ([B[B)V */ JNIEXPORT void JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, jclass, +Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, + jclass, jbyteArray, jbyteArray); @@ -81,7 +82,8 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, j * Signature: ([B[B)V */ JNIEXPORT void JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeRemoveLocalReference(JNIEnv *, jclass, +Java_io_ray_runtime_object_NativeObjectStore_nativeRemoveLocalReference(JNIEnv *, + jclass, jbyteArray, jbyteArray); @@ -100,7 +102,8 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv * Signature: ([B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, jclass, +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, + jclass, jbyteArray); /* @@ -109,7 +112,8 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, jcl * Signature: ([B)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *, jclass, +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnershipInfo(JNIEnv *, + jclass, jbyteArray); /* diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index f79fdee8e..f98f48719 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -32,7 +32,8 @@ inline jint GetHashCodeOfJavaObject(JNIEnv *env, jobject java_object) { thread_local std::unordered_map>> submitter_function_descriptor_cache; -inline const RayFunction &ToRayFunction(JNIEnv *env, jobject functionDescriptor, +inline const RayFunction &ToRayFunction(JNIEnv *env, + jobject functionDescriptor, jint hash) { auto &fd_vector = submitter_function_descriptor_cache[hash]; for (auto &pair : fd_vector) { @@ -89,7 +90,8 @@ inline std::vector> ToTaskArgs(JNIEnv *env, jobject arg inline std::unordered_map ToResources(JNIEnv *env, jobject java_resources) { return JavaMapToNativeMap( - env, 
java_resources, + env, + java_resources, [](JNIEnv *env, jobject java_key) { return JavaStringToNativeString(env, (jstring)java_key); }, @@ -197,7 +199,9 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env, actorCreationOptions, java_actor_creation_options_concurrency_groups); RAY_CHECK(java_concurrency_groups_field != nullptr); JavaListToNativeVector( - env, java_concurrency_groups_field, &concurrency_groups, + env, + java_concurrency_groups_field, + &concurrency_groups, [](JNIEnv *env, jobject java_concurrency_group_impl) { RAY_CHECK(java_concurrency_group_impl != nullptr); jobject java_func_descriptors = @@ -206,7 +210,9 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env, RAY_CHECK_JAVA_EXCEPTION(env); std::vector native_func_descriptors; JavaListToNativeVector( - env, java_func_descriptors, &native_func_descriptors, + env, + java_func_descriptors, + &native_func_descriptors, [](JNIEnv *env, jobject java_func_descriptor) { RAY_CHECK(java_func_descriptor != nullptr); const jint hashcode = GetHashCodeOfJavaObject(env, java_func_descriptor); @@ -217,12 +223,13 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env, }); // Put func_descriptors into this task group. const std::string concurrency_group_name = JavaStringToNativeString( - env, (jstring)env->GetObjectField(java_concurrency_group_impl, - java_concurrency_group_impl_name)); + env, + (jstring)env->GetObjectField(java_concurrency_group_impl, + java_concurrency_group_impl_name)); const uint32_t max_concurrency = env->GetIntField( java_concurrency_group_impl, java_concurrency_group_impl_max_concurrency); - return ray::ConcurrencyGroup{concurrency_group_name, max_concurrency, - native_func_descriptors}; + return ray::ConcurrencyGroup{ + concurrency_group_name, max_concurrency, native_func_descriptors}; }); auto java_serialized_runtime_env = (jstring)env->GetObjectField( actorCreationOptions, java_actor_creation_options_serialized_runtime_env); @@ -299,7 +306,8 @@ inline PlacementGroupCreationOptions ToPlacementGroupCreationOptions( JavaListToNativeVector>( env, java_bundles, &bundles, [](JNIEnv *env, jobject java_bundle) { return JavaMapToNativeMap( - env, java_bundle, + env, + java_bundle, [](JNIEnv *env, jobject java_key) { return JavaStringToNativeString(env, (jstring)java_key); }, @@ -309,7 +317,9 @@ inline PlacementGroupCreationOptions ToPlacementGroupCreationOptions( return value; }); }); - return PlacementGroupCreationOptions(name, ConvertStrategy(java_strategy), bundles, + return PlacementGroupCreationOptions(name, + ConvertStrategy(java_strategy), + bundles, /*is_detached=*/false); } @@ -317,9 +327,14 @@ inline PlacementGroupCreationOptions ToPlacementGroupCreationOptions( extern "C" { #endif -JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( - JNIEnv *env, jclass p, jobject functionDescriptor, jint functionDescriptorHash, - jobject args, jint numReturns, jobject callOptions) { +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(JNIEnv *env, + jclass p, + jobject functionDescriptor, + jint functionDescriptorHash, + jobject args, + jint numReturns, + jobject callOptions) { const auto &ray_function = ToRayFunction(env, functionDescriptor, functionDescriptorHash); auto task_args = ToTaskArgs(env, args); @@ -339,7 +354,9 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub } // TODO (kfstorm): Allow setting `max_retries` via `CallOptions`. 
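Related to the ToResources call above: map conversions take two callables, one per key and one per value, just as list conversions take one per element. A simplified JVM-free sketch of that two-converter shape, with a hypothetical Entry type standing in for a Java map entry:

    #include <functional>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    template <typename K, typename V, typename Entry>
    std::unordered_map<K, V> MapToNativeMap(
        const std::vector<Entry> &entries,
        const std::function<K(const Entry &)> &key_of,
        const std::function<V(const Entry &)> &value_of) {
      std::unordered_map<K, V> result;
      for (const auto &e : entries) {
        result.emplace(key_of(e), value_of(e));
      }
      return result;
    }

    int main() {
      using Entry = std::pair<std::string, double>;
      const std::vector<Entry> raw = {{"CPU", 4.0}, {"GPU", 1.0}};
      auto resources = MapToNativeMap<std::string, double, Entry>(
          raw,
          [](const Entry &e) { return e.first; },
          [](const Entry &e) { return e.second; });
      return resources.at("CPU") == 4.0 ? 0 : 1;
    }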
auto return_refs = - CoreWorkerProcess::GetCoreWorker().SubmitTask(ray_function, task_args, task_options, + CoreWorkerProcess::GetCoreWorker().SubmitTask(ray_function, + task_args, + task_options, /*max_retries=*/0, /*retry_exceptions=*/false, /*scheduling_strategy=*/ @@ -360,17 +377,23 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( - JNIEnv *env, jclass p, jobject functionDescriptor, jint functionDescriptorHash, - jobject args, jobject actorCreationOptions) { + JNIEnv *env, + jclass p, + jobject functionDescriptor, + jint functionDescriptorHash, + jobject args, + jobject actorCreationOptions) { const auto &ray_function = ToRayFunction(env, functionDescriptor, functionDescriptorHash); auto task_args = ToTaskArgs(env, args); auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions); ActorID actor_id; - auto status = CoreWorkerProcess::GetCoreWorker().CreateActor( - ray_function, task_args, actor_creation_options, - /*extension_data*/ "", &actor_id); + auto status = CoreWorkerProcess::GetCoreWorker().CreateActor(ray_function, + task_args, + actor_creation_options, + /*extension_data*/ "", + &actor_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return IdToJavaByteArray(env, actor_id); @@ -378,8 +401,14 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( - JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, - jint functionDescriptorHash, jobject args, jint numReturns, jobject callOptions) { + JNIEnv *env, + jclass p, + jbyteArray actorId, + jobject functionDescriptor, + jint functionDescriptorHash, + jobject args, + jint numReturns, + jobject callOptions) { auto actor_id = JavaByteArrayToId(env, actorId); const auto &ray_function = ToRayFunction(env, functionDescriptor, functionDescriptorHash); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h index 3bd38be4b..a7f416a8f 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h @@ -37,8 +37,8 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub * (Lio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;Lio/ray/api/options/ActorCreationOptions;)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jobject, - jint, jobject, jobject); +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( + JNIEnv *, jclass, jobject, jint, jobject, jobject); /* * Class: io_ray_runtime_task_NativeTaskSubmitter @@ -47,10 +47,8 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, * ([BLio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List; */ JNIEXPORT jobject JNICALL -Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass, - jbyteArray, jobject, - jint, jobject, jint, - jobject); +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( + JNIEnv *, jclass, jbyteArray, jobject, jint, jobject, jint, jobject); /* * Class: io_ray_runtime_task_NativeTaskSubmitter @@ -58,7 +56,8 @@ 
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jcl * Signature: (Lio/ray/api/options/PlacementGroupCreationOptions;)[B */ JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *, jclass, +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *, + jclass, jobject); /* @@ -67,7 +66,8 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv * * Signature: ([B)V */ JNIEXPORT void JNICALL -Java_io_ray_runtime_task_NativeTaskSubmitter_nativeRemovePlacementGroup(JNIEnv *, jclass, +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeRemovePlacementGroup(JNIEnv *, + jclass, jbyteArray); /* diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 5fb14d6e0..eab6968cc 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -248,7 +248,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_jni_exception_util_class = LoadClass(env, "io/ray/runtime/util/JniExceptionUtil"); java_jni_exception_util_get_stack_trace = env->GetStaticMethodID( - java_jni_exception_util_class, "getStackTrace", + java_jni_exception_util_class, + "getStackTrace", "(Ljava/lang/String;ILjava/lang/String;Ljava/lang/Throwable;)Ljava/lang/String;"); java_base_id_class = LoadClass(env, "io/ray/api/id/BaseId"); @@ -263,7 +264,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_function_descriptor_class = LoadClass(env, "io/ray/runtime/functionmanager/FunctionDescriptor"); java_function_descriptor_get_language = - env->GetMethodID(java_function_descriptor_class, "getLanguage", + env->GetMethodID(java_function_descriptor_class, + "getLanguage", "()Lio/ray/runtime/generated/Common$Language;"); java_function_descriptor_to_list = env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;"); @@ -275,10 +277,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_function_arg_id = env->GetFieldID(java_function_arg_class, "id", "Lio/ray/api/id/ObjectId;"); java_function_arg_owner_address = - env->GetFieldID(java_function_arg_class, "ownerAddress", + env->GetFieldID(java_function_arg_class, + "ownerAddress", "Lio/ray/runtime/generated/Common$Address;"); - java_function_arg_value = env->GetFieldID(java_function_arg_class, "value", - "Lio/ray/runtime/object/NativeRayObject;"); + java_function_arg_value = env->GetFieldID( + java_function_arg_class, "value", "Lio/ray/runtime/object/NativeRayObject;"); java_base_task_options_class = LoadClass(env, "io/ray/api/options/BaseTaskOptions"); java_base_task_options_resources = @@ -296,8 +299,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_placement_group_class = LoadClass(env, "io/ray/runtime/placementgroup/PlacementGroupImpl"); - java_placement_group_id = env->GetFieldID(java_placement_group_class, "id", - "Lio/ray/api/id/PlacementGroupId;"); + java_placement_group_id = env->GetFieldID( + java_placement_group_class, "id", "Lio/ray/api/id/PlacementGroupId;"); java_placement_group_creation_options_class = LoadClass(env, "io/ray/api/options/PlacementGroupCreationOptions"); @@ -308,7 +311,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_placement_group_creation_options_bundles = env->GetFieldID( java_placement_group_creation_options_class, "bundles", "Ljava/util/List;"); java_placement_group_creation_options_strategy = - env->GetFieldID(java_placement_group_creation_options_class, "strategy", + 
env->GetFieldID(java_placement_group_creation_options_class, + "strategy", "Lio/ray/api/placementgroup/PlacementStrategy;"); java_placement_group_creation_options_strategy_value = env->GetMethodID( java_placement_group_creation_options_strategy_class, "value", "()I"); @@ -318,7 +322,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_name = env->GetFieldID(java_actor_creation_options_class, "name", "Ljava/lang/String;"); java_actor_creation_options_lifetime = - env->GetFieldID(java_actor_creation_options_class, "lifetime", + env->GetFieldID(java_actor_creation_options_class, + "lifetime", "Lio/ray/api/options/ActorLifetime;"); java_actor_creation_options_max_restarts = env->GetFieldID(java_actor_creation_options_class, "maxRestarts", "I"); @@ -327,7 +332,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_max_concurrency = env->GetFieldID(java_actor_creation_options_class, "maxConcurrency", "I"); java_actor_creation_options_group = - env->GetFieldID(java_actor_creation_options_class, "group", + env->GetFieldID(java_actor_creation_options_class, + "group", "Lio/ray/api/placementgroup/PlacementGroup;"); java_actor_creation_options_bundle_index = env->GetFieldID(java_actor_creation_options_class, "bundleIndex", "I"); @@ -375,7 +381,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_task_executor_parse_function_arguments = env->GetMethodID( java_task_executor_class, "checkByteBufferArguments", "(Ljava/util/List;)[Z"); java_task_executor_execute = - env->GetMethodID(java_task_executor_class, "execute", + env->GetMethodID(java_task_executor_class, + "execute", "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); java_native_task_executor_class = LoadClass(env, "io/ray/runtime/task/NativeTaskExecutor"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 86a018d20..2e8b82b51 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -284,9 +284,13 @@ extern JavaVM *jvm; if (throwable) { \ jstring java_file_name = env->NewStringUTF(__FILE__); \ jstring java_function = env->NewStringUTF(__func__); \ - jobject java_error_message = env->CallStaticObjectMethod( \ - java_jni_exception_util_class, java_jni_exception_util_get_stack_trace, \ - java_file_name, __LINE__, java_function, throwable); \ + jobject java_error_message = \ + env->CallStaticObjectMethod(java_jni_exception_util_class, \ + java_jni_exception_util_get_stack_trace, \ + java_file_name, \ + __LINE__, \ + java_function, \ + throwable); \ std::string error_message = \ JavaStringToNativeString(env, static_cast(java_error_message)); \ env->DeleteLocalRef(throwable); \ @@ -342,8 +346,8 @@ inline std::string JavaByteArrayToNativeString(JNIEnv *env, const jbyteArray &by template inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { std::string id_str(ID::Size(), 0); - env->GetByteArrayRegion(bytes, 0, ID::Size(), - reinterpret_cast(&id_str.front())); + env->GetByteArrayRegion( + bytes, 0, ID::Size(), reinterpret_cast(&id_str.front())); auto arr_size = static_cast(env->GetArrayLength(bytes)); RAY_CHECK(arr_size == ID::Size()) << "ID length should be " << ID::Size() << " instead of " << arr_size; @@ -354,8 +358,8 @@ inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { template inline jbyteArray IdToJavaByteArray(JNIEnv *env, const ID &id) { jbyteArray array = env->NewByteArray(ID::Size()); - env->SetByteArrayRegion(array, 0, ID::Size(), - 
reinterpret_cast(id.Data())); + env->SetByteArrayRegion( + array, 0, ID::Size(), reinterpret_cast(id.Data())); return array; } @@ -369,8 +373,8 @@ inline jobject IdToJavaByteBuffer(JNIEnv *env, const ID &id) { /// Convert C++ String to a Java ByteArray. inline jbyteArray NativeStringToJavaByteArray(JNIEnv *env, const std::string &str) { jbyteArray array = env->NewByteArray(str.size()); - env->SetByteArrayRegion(array, 0, str.size(), - reinterpret_cast(str.c_str())); + env->SetByteArrayRegion( + array, 0, str.size(), reinterpret_cast(str.c_str())); return array; } @@ -385,7 +389,9 @@ inline std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) { /// Convert a Java List to C++ std::vector. template inline void JavaListToNativeVector( - JNIEnv *env, jobject java_list, std::vector *native_vector, + JNIEnv *env, + jobject java_list, + std::vector *native_vector, std::function element_converter) { int size = env->CallIntMethod(java_list, java_list_size); RAY_CHECK_JAVA_EXCEPTION(env); @@ -399,7 +405,8 @@ inline void JavaListToNativeVector( } /// Convert a Java List to C++ std::vector. -inline void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, +inline void JavaStringListToNativeStringVector(JNIEnv *env, + jobject java_list, std::vector *native_vector) { JavaListToNativeVector( env, java_list, native_vector, [](JNIEnv *env, jobject jstr) { @@ -408,7 +415,8 @@ inline void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, } /// Convert a Java long array to C++ std::vector. -inline void JavaLongArrayToNativeLongVector(JNIEnv *env, jlongArray long_array, +inline void JavaLongArrayToNativeLongVector(JNIEnv *env, + jlongArray long_array, std::vector *native_vector) { jlong *long_array_ptr = env->GetLongArrayElements(long_array, nullptr); jsize vec_size = env->GetArrayLength(long_array); @@ -419,7 +427,8 @@ inline void JavaLongArrayToNativeLongVector(JNIEnv *env, jlongArray long_array, } /// Convert a Java double array to C++ std::vector. -inline void JavaDoubleArrayToNativeDoubleVector(JNIEnv *env, jdoubleArray double_array, +inline void JavaDoubleArrayToNativeDoubleVector(JNIEnv *env, + jdoubleArray double_array, std::vector *native_vector) { jdouble *double_array_ptr = env->GetDoubleArrayElements(double_array, nullptr); jsize vec_size = env->GetArrayLength(double_array); @@ -432,11 +441,12 @@ inline void JavaDoubleArrayToNativeDoubleVector(JNIEnv *env, jdoubleArray double /// Convert a C++ std::vector to a Java List. 
template inline jobject NativeVectorToJavaList( - JNIEnv *env, const std::vector &native_vector, + JNIEnv *env, + const std::vector &native_vector, std::function element_converter) { - jobject java_list = - env->NewObject(java_array_list_class, java_array_list_init_with_capacity, - (jint)native_vector.size()); + jobject java_list = env->NewObject(java_array_list_class, + java_array_list_init_with_capacity, + (jint)native_vector.size()); RAY_CHECK_JAVA_EXCEPTION(env); for (auto it = native_vector.begin(); it != native_vector.end(); ++it) { auto element = element_converter(env, *it); @@ -451,8 +461,9 @@ inline jobject NativeVectorToJavaList( inline jobject NativeStringVectorToJavaStringList( JNIEnv *env, const std::vector &native_vector) { return NativeVectorToJavaList( - env, native_vector, - [](JNIEnv *env, const std::string &str) { return env->NewStringUTF(str.c_str()); }); + env, native_vector, [](JNIEnv *env, const std::string &str) { + return env->NewStringUTF(str.c_str()); + }); } template @@ -466,7 +477,8 @@ inline jobject NativeIdVectorToJavaByteArrayList(JNIEnv *env, /// Convert a Java Map to a C++ std::unordered_map template inline std::unordered_map JavaMapToNativeMap( - JNIEnv *env, jobject java_map, + JNIEnv *env, + jobject java_map, const std::function &key_converter, const std::function &value_converter) { std::unordered_map native_map; @@ -500,7 +512,8 @@ inline std::unordered_map JavaMapToNativeMap( /// Convert a C++ std::unordered_map to a Java Map template inline jobject NativeMapToJavaMap( - JNIEnv *env, const std::unordered_map &native_map, + JNIEnv *env, + const std::unordered_map &native_map, const std::function &key_converter, const std::function &value_converter) { jobject java_map = env->NewObject(java_hash_map_class, java_hash_map_init); @@ -524,7 +537,9 @@ inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, } jbyteArray java_byte_array = env->NewByteArray(buffer->Size()); if (buffer->Size() > 0) { - env->SetByteArrayRegion(java_byte_array, 0, buffer->Size(), + env->SetByteArrayRegion(java_byte_array, + 0, + buffer->Size(), reinterpret_cast(buffer->Data())); } return java_byte_array; @@ -581,7 +596,9 @@ inline jobject NativeRayObjectToJavaNativeRayObject( auto java_data = NativeBufferToJavaByteArray(env, rayObject->GetData()); auto java_metadata = NativeBufferToJavaByteArray(env, rayObject->GetMetadata()); auto java_obj = env->NewObject(java_native_ray_object_class, - java_native_ray_object_init, java_data, java_metadata); + java_native_ray_object_init, + java_data, + java_metadata); RAY_CHECK_JAVA_EXCEPTION(env); env->DeleteLocalRef(java_metadata); env->DeleteLocalRef(java_data); @@ -601,8 +618,10 @@ inline jobject NativeRayFunctionDescriptorToJavaStringList( FunctionDescriptorType::kPythonFunctionDescriptor) { auto typed_descriptor = function_descriptor->As(); std::vector function_descriptor_list = { - typed_descriptor->ModuleName(), typed_descriptor->ClassName(), - typed_descriptor->FunctionName(), typed_descriptor->FunctionHash()}; + typed_descriptor->ModuleName(), + typed_descriptor->ClassName(), + typed_descriptor->FunctionName(), + typed_descriptor->FunctionHash()}; return NativeStringVectorToJavaStringList(env, function_descriptor_list); } RAY_LOG(FATAL) << "Unknown function descriptor type: " << function_descriptor->Type(); @@ -636,7 +655,7 @@ inline std::shared_ptr SerializeActorCreationException( env->CallObjectMethod(creation_exception, java_ray_exception_to_bytes)); int len = env->GetArrayLength(exception_jbyte_array); auto buf = 
std::make_shared(len); - env->GetByteArrayRegion(exception_jbyte_array, 0, len, - reinterpret_cast(buf->Data())); + env->GetByteArrayRegion( + exception_jbyte_array, 0, len, reinterpret_cast(buf->Data())); return buf; } diff --git a/src/ray/core_worker/object_recovery_manager.cc b/src/ray/core_worker/object_recovery_manager.cc index 790eeba43..3c7f8470a 100644 --- a/src/ray/core_worker/object_recovery_manager.cc +++ b/src/ray/core_worker/object_recovery_manager.cc @@ -92,7 +92,8 @@ void ObjectRecoveryManager::PinOrReconstructObject( } void ObjectRecoveryManager::PinExistingObjectCopy( - const ObjectID &object_id, const rpc::Address &raylet_address, + const ObjectID &object_id, + const rpc::Address &raylet_address, const std::vector &other_locations) { // If a copy still exists, pin the object by sending a // PinObjectIDs RPC. @@ -109,14 +110,16 @@ void ObjectRecoveryManager::PinExistingObjectCopy( if (client_it == remote_object_pinning_clients_.end()) { RAY_LOG(DEBUG) << "Connecting to raylet " << node_id; client_it = remote_object_pinning_clients_ - .emplace(node_id, client_factory_(raylet_address.ip_address(), - raylet_address.port())) + .emplace(node_id, + client_factory_(raylet_address.ip_address(), + raylet_address.port())) .first; } client = client_it->second; } - client->PinObjectIDs(rpc_address_, {object_id}, + client->PinObjectIDs(rpc_address_, + {object_id}, [this, object_id, other_locations, node_id]( const Status &status, const rpc::PinObjectIDsReply &reply) { if (status.ok()) { @@ -146,7 +149,8 @@ void ObjectRecoveryManager::ReconstructObject(const ObjectID &object_id) { rpc::ErrorType::OBJECT_UNRECONSTRUCTABLE_LINEAGE_EVICTED, /*pin_object=*/true); } else { - recovery_failure_callback_(object_id, rpc::ErrorType::OBJECT_LOST, + recovery_failure_callback_(object_id, + rpc::ErrorType::OBJECT_LOST, /*pin_object=*/true); } return; @@ -169,7 +173,8 @@ void ObjectRecoveryManager::ReconstructObject(const ObjectID &object_id) { // worker, or if there was a bug in reconstruction that caused us to GC // the dependency ref. // We do not pin the dependency because we may not be the owner. 
- recovery_failure_callback_(dep, rpc::ErrorType::OBJECT_UNRECONSTRUCTABLE, + recovery_failure_callback_(dep, + rpc::ErrorType::OBJECT_UNRECONSTRUCTABLE, /*pin_object=*/false); } } @@ -177,7 +182,8 @@ void ObjectRecoveryManager::ReconstructObject(const ObjectID &object_id) { RAY_LOG(INFO) << "Failed to reconstruct object " << object_id << " because lineage has already been deleted"; recovery_failure_callback_( - object_id, rpc::ErrorType::OBJECT_UNRECONSTRUCTABLE_MAX_ATTEMPTS_EXCEEDED, + object_id, + rpc::ErrorType::OBJECT_UNRECONSTRUCTABLE_MAX_ATTEMPTS_EXCEEDED, /*pin_object=*/true); } } diff --git a/src/ray/core_worker/object_recovery_manager.h b/src/ray/core_worker/object_recovery_manager.h index 6e58b7cb9..b224e972d 100644 --- a/src/ray/core_worker/object_recovery_manager.h +++ b/src/ray/core_worker/object_recovery_manager.h @@ -34,22 +34,22 @@ typedef std::function +typedef std::function ObjectRecoveryFailureCallback; class ObjectRecoveryManager { public: - ObjectRecoveryManager(const rpc::Address &rpc_address, - ObjectPinningClientFactoryFn client_factory, - std::shared_ptr local_object_pinning_client, - std::function - object_lookup, - std::shared_ptr task_resubmitter, - std::shared_ptr reference_counter, - std::shared_ptr in_memory_store, - const ObjectRecoveryFailureCallback &recovery_failure_callback) + ObjectRecoveryManager( + const rpc::Address &rpc_address, + ObjectPinningClientFactoryFn client_factory, + std::shared_ptr local_object_pinning_client, + std::function object_lookup, + std::shared_ptr task_resubmitter, + std::shared_ptr reference_counter, + std::shared_ptr in_memory_store, + const ObjectRecoveryFailureCallback &recovery_failure_callback) : task_resubmitter_(task_resubmitter), reference_counter_(reference_counter), rpc_address_(rpc_address), diff --git a/src/ray/core_worker/profiling.cc b/src/ray/core_worker/profiling.cc index 7402ec563..1d9b4d306 100644 --- a/src/ray/core_worker/profiling.cc +++ b/src/ray/core_worker/profiling.cc @@ -28,7 +28,8 @@ ProfileEvent::ProfileEvent(const std::shared_ptr &profiler, rpc_event_.set_start_time(absl::GetCurrentTimeNanos() / 1e9); } -Profiler::Profiler(WorkerContext &worker_context, const std::string &node_ip_address, +Profiler::Profiler(WorkerContext &worker_context, + const std::string &node_ip_address, instrumented_io_context &io_service, const std::shared_ptr &gcs_client) : io_service_(io_service), @@ -39,7 +40,8 @@ Profiler::Profiler(WorkerContext &worker_context, const std::string &node_ip_add rpc_profile_data_->set_component_id(worker_context.GetWorkerID().Binary()); rpc_profile_data_->set_node_ip_address(node_ip_address); periodical_runner_.RunFnPeriodically( - [this] { FlushEvents(); }, 1000, + [this] { FlushEvents(); }, + 1000, "CoreWorker.deadline_timer.flush_profiling_events"); } diff --git a/src/ray/core_worker/profiling.h b/src/ray/core_worker/profiling.h index 42cbb8f9e..91b615c54 100644 --- a/src/ray/core_worker/profiling.h +++ b/src/ray/core_worker/profiling.h @@ -29,7 +29,8 @@ namespace worker { class Profiler { public: - Profiler(WorkerContext &worker_context, const std::string &node_ip_address, + Profiler(WorkerContext &worker_context, + const std::string &node_ip_address, instrumented_io_context &io_service, const std::shared_ptr &gcs_client); diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 313f75255..47a7985e6 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -88,8 +88,8 @@ bool 
ReferenceCounter::AddBorrowedObject(const ObjectID &object_id, const ObjectID &outer_id, const rpc::Address &owner_address, bool foreign_owner_already_monitoring) { absl::MutexLock lock(&mutex_); - return AddBorrowedObjectInternal(object_id, outer_id, owner_address, - foreign_owner_already_monitoring); + return AddBorrowedObjectInternal( + object_id, outer_id, owner_address, foreign_owner_already_monitoring); } bool ReferenceCounter::AddBorrowedObjectInternal(const ObjectID &object_id, @@ -174,7 +174,8 @@ void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const std::vector<ObjectID> &inner_ids, const rpc::Address &owner_address, const std::string &call_site, - const int64_t object_size, bool is_reconstructable, + const int64_t object_size, + bool is_reconstructable, bool add_local_ref, const absl::optional<NodeID> &pinned_at_raylet_id) { RAY_LOG(DEBUG) << "Adding owned object " << object_id; @@ -187,8 +188,12 @@ void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, // TODO(swang): Objects that are not reconstructable should not increment // their arguments' lineage ref counts. auto it = object_id_refs_ - .emplace(object_id, Reference(owner_address, call_site, object_size, - is_reconstructable, pinned_at_raylet_id)) + .emplace(object_id, + Reference(owner_address, + call_site, + object_size, + is_reconstructable, + pinned_at_raylet_id)) .first; if (!inner_ids.empty()) { // Mark that this object ID contains other inner IDs. Then, we will not GC @@ -301,7 +306,8 @@ void ReferenceCounter::RemoveLocalReferenceInternal(const ObjectID &object_id, void ReferenceCounter::UpdateSubmittedTaskReferences( const std::vector<ObjectID> return_ids, const std::vector<ObjectID> &argument_ids_to_add, - const std::vector<ObjectID> &argument_ids_to_remove, std::vector<ObjectID> *deleted) { + const std::vector<ObjectID> &argument_ids_to_remove, + std::vector<ObjectID> *deleted) { absl::MutexLock lock(&mutex_); for (const auto &return_id : return_ids) { UpdateObjectPendingCreation(return_id, true); @@ -325,8 +331,8 @@ void ReferenceCounter::UpdateSubmittedTaskReferences( } // Release the submitted task ref and the lineage ref for any argument IDs // whose values were inlined.
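One idiom visible in AddBorrowedObject above and repeated across this file: the public method takes the mutex, then delegates to an *Internal variant that assumes the lock is already held, so other locked code paths can reuse it without re-acquiring. A stand-alone sketch with a stand-in class (requires Abseil):

    #include "absl/base/thread_annotations.h"
    #include "absl/synchronization/mutex.h"

    class Counter {
     public:
      bool Add(int delta) ABSL_LOCKS_EXCLUDED(mutex_) {
        absl::MutexLock lock(&mutex_);  // public entry point takes the lock
        return AddInternal(delta);
      }

     private:
      // Callers must already hold mutex_; the annotation lets clang's
      // thread-safety analysis enforce that at compile time.
      bool AddInternal(int delta) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
        value_ += delta;
        return value_ >= 0;
      }

      absl::Mutex mutex_;
      int value_ ABSL_GUARDED_BY(mutex_) = 0;
    };

    int main() {
      Counter c;
      return c.Add(1) ? 0 : 1;
    }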
- RemoveSubmittedTaskReferences(argument_ids_to_remove, /*release_lineage=*/true, - deleted); + RemoveSubmittedTaskReferences( + argument_ids_to_remove, /*release_lineage=*/true, deleted); } void ReferenceCounter::UpdateResubmittedTaskReferences( @@ -347,9 +353,12 @@ void ReferenceCounter::UpdateResubmittedTaskReferences( } void ReferenceCounter::UpdateFinishedTaskReferences( - const std::vector return_ids, const std::vector &argument_ids, - bool release_lineage, const rpc::Address &worker_addr, - const ReferenceTableProto &borrowed_refs, std::vector *deleted) { + const std::vector return_ids, + const std::vector &argument_ids, + bool release_lineage, + const rpc::Address &worker_addr, + const ReferenceTableProto &borrowed_refs, + std::vector *deleted) { absl::MutexLock lock(&mutex_); for (const auto &return_id : return_ids) { UpdateObjectPendingCreation(return_id, false); @@ -409,7 +418,8 @@ int64_t ReferenceCounter::ReleaseLineageReferences(ReferenceTable::iterator ref) } void ReferenceCounter::RemoveSubmittedTaskReferences( - const std::vector &argument_ids, bool release_lineage, + const std::vector &argument_ids, + bool release_lineage, std::vector *deleted) { for (const ObjectID &argument_id : argument_ids) { RAY_LOG(DEBUG) << "Releasing ref for submitted task argument " << argument_id; @@ -690,7 +700,8 @@ void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id, } bool ReferenceCounter::IsPlasmaObjectPinnedOrSpilled(const ObjectID &object_id, - bool *owned_by_us, NodeID *pinned_at, + bool *owned_by_us, + NodeID *pinned_at, bool *spilled) const { absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); @@ -740,7 +751,8 @@ ReferenceCounter::GetAllReferenceCounts() const { void ReferenceCounter::PopAndClearLocalBorrowers( const std::vector &borrowed_ids, - ReferenceCounter::ReferenceTableProto *proto, std::vector *deleted) { + ReferenceCounter::ReferenceTableProto *proto, + std::vector *deleted) { absl::MutexLock lock(&mutex_); ReferenceTable borrowed_refs; for (const auto &borrowed_id : borrowed_ids) { @@ -870,7 +882,8 @@ void ReferenceCounter::MergeRemoteBorrowers(const ObjectID &object_id, for (const auto &contained_in_borrowed_id : borrower_it->second.contained_in_borrowed_ids) { RAY_CHECK(borrower_ref.owner_address); - AddBorrowedObjectInternal(object_id, contained_in_borrowed_id, + AddBorrowedObjectInternal(object_id, + contained_in_borrowed_id, *borrower_ref.owner_address, /*foreign_owner_already_monitoring=*/false); } @@ -890,8 +903,8 @@ void ReferenceCounter::MergeRemoteBorrowers(const ObjectID &object_id, // If the borrower stored this object ID inside another object ID that it did // not own, then mark that the object ID is nested inside another. for (const auto &stored_in_object : borrower_ref.stored_in_objects) { - AddNestedObjectIdsInternal(stored_in_object.first, {object_id}, - stored_in_object.second); + AddNestedObjectIdsInternal( + stored_in_object.first, {object_id}, stored_in_object.second); } // Recursively merge any references that were contained in this object, to @@ -903,7 +916,8 @@ void ReferenceCounter::MergeRemoteBorrowers(const ObjectID &object_id, } void ReferenceCounter::CleanupBorrowersOnRefRemoved( - const ReferenceTable &new_borrower_refs, const ObjectID &object_id, + const ReferenceTable &new_borrower_refs, + const ObjectID &object_id, const rpc::WorkerAddress &borrower_addr) { absl::MutexLock lock(&mutex_); // Merge in any new borrowers that the previous borrower learned of. 
@@ -933,20 +947,21 @@ void ReferenceCounter::WaitForRefRemoved(const ReferenceTable::iterator &ref_it, request->set_subscriber_worker_id(rpc_address_.ToProto().worker_id()); // If the message is published, this callback will be invoked. - const auto message_published_callback = [this, addr, - object_id](const rpc::PubMessage &msg) { - RAY_CHECK(msg.has_worker_ref_removed_message()); - const ReferenceTable new_borrower_refs = - ReferenceTableFromProto(msg.worker_ref_removed_message().borrowed_refs()); - RAY_LOG(DEBUG) << "WaitForRefRemoved returned for " << object_id - << ", dest=" << addr.worker_id; + const auto message_published_callback = + [this, addr, object_id](const rpc::PubMessage &msg) { + RAY_CHECK(msg.has_worker_ref_removed_message()); + const ReferenceTable new_borrower_refs = + ReferenceTableFromProto(msg.worker_ref_removed_message().borrowed_refs()); + RAY_LOG(DEBUG) << "WaitForRefRemoved returned for " << object_id + << ", dest=" << addr.worker_id; - CleanupBorrowersOnRefRemoved(new_borrower_refs, object_id, addr); - // Unsubscribe the object once the message is published. - RAY_CHECK( - object_info_subscriber_->Unsubscribe(rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL, - addr.ToProto(), object_id.Binary())); - }; + CleanupBorrowersOnRefRemoved(new_borrower_refs, object_id, addr); + // Unsubscribe the object once the message is published. + RAY_CHECK(object_info_subscriber_->Unsubscribe( + rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL, + addr.ToProto(), + object_id.Binary())); + }; // If the borrower is failed, this callback will be called. const auto publisher_failed_callback = [this, addr](const std::string &object_id_binary, @@ -959,10 +974,14 @@ void ReferenceCounter::WaitForRefRemoved(const ReferenceTable::iterator &ref_it, CleanupBorrowersOnRefRemoved({}, object_id, addr); }; - RAY_CHECK(object_info_subscriber_->Subscribe( - std::move(sub_message), rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL, - addr.ToProto(), object_id.Binary(), /*subscribe_done_callback=*/nullptr, - message_published_callback, publisher_failed_callback)); + RAY_CHECK( + object_info_subscriber_->Subscribe(std::move(sub_message), + rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL, + addr.ToProto(), + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + message_published_callback, + publisher_failed_callback)); } void ReferenceCounter::AddNestedObjectIds(const ObjectID &object_id, @@ -973,7 +992,8 @@ void ReferenceCounter::AddNestedObjectIds(const ObjectID &object_id, } void ReferenceCounter::AddNestedObjectIdsInternal( - const ObjectID &object_id, const std::vector &inner_ids, + const ObjectID &object_id, + const std::vector &inner_ids, const rpc::WorkerAddress &owner_address) { RAY_CHECK(!owner_address.worker_id.IsNil()); auto it = object_id_refs_.find(object_id); @@ -1039,7 +1059,8 @@ void ReferenceCounter::HandleRefRemoved(const ObjectID &object_id) { } ReferenceTable borrowed_refs; RAY_UNUSED(GetAndClearLocalBorrowersInternal(object_id, - /*for_ref_removed=*/true, &borrowed_refs)); + /*for_ref_removed=*/true, + &borrowed_refs)); for (const auto &pair : borrowed_refs) { RAY_LOG(DEBUG) << pair.first << " has " << pair.second.borrowers.size() << " borrowers, stored in " << pair.second.stored_in_objects.size(); @@ -1060,7 +1081,8 @@ void ReferenceCounter::HandleRefRemoved(const ObjectID &object_id) { } void ReferenceCounter::SetRefRemovedCallback( - const ObjectID &object_id, const ObjectID &contained_in_id, + const ObjectID &object_id, + const ObjectID &contained_in_id, const rpc::Address 
&owner_address, const ReferenceCounter::ReferenceRemovedCallback &ref_removed_callback) { absl::MutexLock lock(&mutex_); @@ -1191,7 +1213,8 @@ size_t ReferenceCounter::GetObjectSize(const ObjectID &object_id) const { bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id, const std::string spilled_url, - const NodeID &spilled_node_id, int64_t size) { + const NodeID &spilled_node_id, + int64_t size) { absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); if (it == object_id_refs_.end()) { diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 71b4e138c..724903165 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -36,13 +36,18 @@ class ReferenceCounterInterface { public: virtual void AddLocalReference(const ObjectID &object_id, const std::string &call_site) = 0; - virtual bool AddBorrowedObject(const ObjectID &object_id, const ObjectID &outer_id, + virtual bool AddBorrowedObject(const ObjectID &object_id, + const ObjectID &outer_id, const rpc::Address &owner_address, bool foreign_owner_already_monitoring = false) = 0; virtual void AddOwnedObject( - const ObjectID &object_id, const std::vector &contained_ids, - const rpc::Address &owner_address, const std::string &call_site, - const int64_t object_size, bool is_reconstructable, bool add_local_ref, + const ObjectID &object_id, + const std::vector &contained_ids, + const rpc::Address &owner_address, + const std::string &call_site, + const int64_t object_size, + bool is_reconstructable, + bool add_local_ref, const absl::optional &pinned_at_raylet_id = absl::optional()) = 0; virtual bool SetDeleteCallback( const ObjectID &object_id, @@ -146,7 +151,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[out] deleted The object IDs whos reference counts reached zero. void UpdateFinishedTaskReferences(const std::vector return_ids, const std::vector &argument_ids, - bool release_lineage, const rpc::Address &worker_addr, + bool release_lineage, + const rpc::Address &worker_addr, const ReferenceTableProto &borrowed_refs, std::vector *deleted) LOCKS_EXCLUDED(mutex_); @@ -174,12 +180,15 @@ class ReferenceCounter : public ReferenceCounterInterface, /// corresponding ObjectRef has been returned to the language frontend. /// \param[in] pinned_at_raylet_id The primary location for the object, if it /// is already known. This is only used for ray.put calls. - void AddOwnedObject( - const ObjectID &object_id, const std::vector &contained_ids, - const rpc::Address &owner_address, const std::string &call_site, - const int64_t object_size, bool is_reconstructable, bool add_local_ref, - const absl::optional &pinned_at_raylet_id = absl::optional()) - LOCKS_EXCLUDED(mutex_); + void AddOwnedObject(const ObjectID &object_id, + const std::vector &contained_ids, + const rpc::Address &owner_address, + const std::string &call_site, + const int64_t object_size, + bool is_reconstructable, + bool add_local_ref, + const absl::optional &pinned_at_raylet_id = + absl::optional()) LOCKS_EXCLUDED(mutex_); /// Update the size of the object. /// @@ -197,7 +206,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// out-of-band. /// task ID (for non-actors) or the actor ID of the owner. /// \param[in] owner_address The owner's address. 
- bool AddBorrowedObject(const ObjectID &object_id, const ObjectID &outer_id, + bool AddBorrowedObject(const ObjectID &object_id, + const ObjectID &outer_id, const rpc::Address &owner_address, bool foreign_owner_already_monitoring = false) LOCKS_EXCLUDED(mutex_); @@ -251,7 +261,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[in] owner_address The owner of object_id's address. /// \param[in] ref_removed_callback The callback to call when we are no /// longer borrowing the object. - void SetRefRemovedCallback(const ObjectID &object_id, const ObjectID &contained_in_id, + void SetRefRemovedCallback(const ObjectID &object_id, + const ObjectID &contained_in_id, const rpc::Address &owner_address, const ReferenceRemovedCallback &ref_removed_callback) LOCKS_EXCLUDED(mutex_); @@ -344,9 +355,10 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[out] spilled Whether this object has been spilled. /// pinned. Set to nil if the object is not pinned. /// \return True if the reference exists, false otherwise. - bool IsPlasmaObjectPinnedOrSpilled(const ObjectID &object_id, bool *owned_by_us, - NodeID *pinned_at, bool *spilled) const - LOCKS_EXCLUDED(mutex_); + bool IsPlasmaObjectPinnedOrSpilled(const ObjectID &object_id, + bool *owned_by_us, + NodeID *pinned_at, + bool *spilled) const LOCKS_EXCLUDED(mutex_); /// Get and reset the objects that were pinned or spilled on the given node. /// This method should be called upon a node failure, to trigger @@ -434,8 +446,10 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[in] spilled_node_id The ID of the node on which the object was spilled. /// \param[in] size The size of the object. /// \return True if the reference exists and is in scope, false otherwise. - bool HandleObjectSpilled(const ObjectID &object_id, const std::string spilled_url, - const NodeID &spilled_node_id, int64_t size); + bool HandleObjectSpilled(const ObjectID &object_id, + const std::string spilled_url, + const NodeID &spilled_node_id, + int64_t size); /// Get locality data for object. This is used by the leasing policy to implement /// locality-aware leasing. @@ -486,8 +500,10 @@ class ReferenceCounter : public ReferenceCounterInterface, Reference(std::string call_site, const int64_t object_size) : call_site(call_site), object_size(object_size) {} /// Constructor for a reference that we created. - Reference(const rpc::Address &owner_address, std::string call_site, - const int64_t object_size, bool is_reconstructable, + Reference(const rpc::Address &owner_address, + std::string call_site, + const int64_t object_size, + bool is_reconstructable, const absl::optional<NodeID> &pinned_at_raylet_id) : call_site(call_site), object_size(object_size), @@ -682,7 +698,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// inlined dependencies are inlined or when the task finishes for plasma /// dependencies. void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &argument_ids, - bool release_lineage, std::vector<ObjectID> *deleted) + bool release_lineage, + std::vector<ObjectID> *deleted) EXCLUSIVE_LOCKS_REQUIRED(mutex_); /// Helper method to mark that this ObjectID contains another ObjectID(s). @@ -718,7 +735,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// that contained it. We don't need this anymore because we already marked /// that the borrowed ID contained another ID in the returned /// borrowed_refs.
- bool GetAndClearLocalBorrowersInternal(const ObjectID &object_id, bool for_ref_removed, + bool GetAndClearLocalBorrowersInternal(const ObjectID &object_id, + bool for_ref_removed, ReferenceTable *borrowed_refs) EXCLUSIVE_LOCKS_REQUIRED(mutex_); @@ -760,7 +778,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[in] foreign_owner_already_monitoring Whether to set the bit that an /// externally assigned owner is monitoring the lifetime of this /// object. This is the case for `ray.put(..., _owner=ZZZ)`. - bool AddBorrowedObjectInternal(const ObjectID &object_id, const ObjectID &outer_id, + bool AddBorrowedObjectInternal(const ObjectID &object_id, + const ObjectID &outer_id, const rpc::Address &owner_address, bool foreign_owner_already_monitoring) EXCLUSIVE_LOCKS_REQUIRED(mutex_); diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 254f9a8ae..003f67b42 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -40,8 +40,10 @@ class ReferenceCountTest : public ::testing::Test { rpc::Address addr; publisher_ = std::make_shared<pubsub::MockPublisher>(); subscriber_ = std::make_shared<pubsub::MockSubscriber>(); - rc = std::make_unique<ReferenceCounter>(addr, publisher_.get(), subscriber_.get(), - [](const NodeID &node_id) { return true; }); + rc = std::make_unique<ReferenceCounter>( + addr, publisher_.get(), subscriber_.get(), [](const NodeID &node_id) { + return true; + }); } virtual void TearDown() { @@ -65,7 +67,9 @@ class ReferenceCountLineageEnabledTest : public ::testing::Test { publisher_ = std::make_shared<pubsub::MockPublisher>(); subscriber_ = std::make_shared<pubsub::MockSubscriber>(); rc = std::make_unique<ReferenceCounter>( - addr, publisher_.get(), subscriber_.get(), + addr, + publisher_.get(), + subscriber_.get(), [](const NodeID &node_id) { return true; }, /*lineage_pinning_enabled=*/true); } @@ -107,7 +111,8 @@ static std::string GenerateID(UniqueID publisher_id, UniqueID subscriber_id) { class MockCoreWorkerClientInterface : public rpc::CoreWorkerClientInterface { public: ~MockCoreWorkerClientInterface() = default; - virtual void WaitForRefRemoved(const ObjectID object_id, const ObjectID contained_in_id, + virtual void WaitForRefRemoved(const ObjectID object_id, + const ObjectID contained_in_id, rpc::Address owner_address) = 0; }; @@ -120,7 +125,8 @@ class MockDistributedSubscriber : public pubsub::SubscriberInterface { pubsub::pub_internal::SubscriptionIndex *directory, SubscriptionCallbackMap *subscription_callback_map, SubscriptionFailureCallbackMap *subscription_failure_callback_map, - WorkerID subscriber_id, PublisherFactoryFn client_factory) + WorkerID subscriber_id, + PublisherFactoryFn client_factory) : directory_(directory), subscription_callback_map_(subscription_callback_map), subscription_failure_callback_map_(subscription_failure_callback_map), @@ -131,7 +137,8 @@ class MockDistributedSubscriber : public pubsub::SubscriberInterface { bool Subscribe( const std::unique_ptr<rpc::SubMessage> sub_message, - const rpc::ChannelType channel_type, const rpc::Address &publisher_address, + const rpc::ChannelType channel_type, + const rpc::Address &publisher_address, const std::string &key_id_binary, pubsub::SubscribeDoneCallback subscribe_done_callback, pubsub::SubscriptionItemCallback subscription_callback, pubsub::SubscriptionFailureCallback subscription_failure_callback) override { @@ -173,7 +180,8 @@ class MockDistributedSubscriber : public pubsub::SubscriberInterface { bool SubscribeChannel( const std::unique_ptr<rpc::SubMessage> sub_message, - const rpc::ChannelType channel_type, const rpc::Address &publisher_address, + const rpc::ChannelType channel_type, + const rpc::Address &publisher_address,
pubsub::SubscribeDoneCallback subscribe_done_callback, pubsub::SubscriptionItemCallback subscription_callback, pubsub::SubscriptionFailureCallback subscription_failure_callback) override { @@ -278,15 +286,23 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { MockWorkerClient(const std::string &addr, PublisherFactoryFn client_factory = nullptr) : address_(CreateRandomAddress(addr)), publisher_(std::make_shared<MockDistributedPublisher>( - &directory, &subscription_callback_map, &subscription_failure_callback_map, + &directory, + &subscription_callback_map, + &subscription_failure_callback_map, WorkerID::FromBinary(address_.worker_id()))), subscriber_(std::make_shared<MockDistributedSubscriber>( - &directory, &subscription_callback_map, &subscription_failure_callback_map, - WorkerID::FromBinary(address_.worker_id()), client_factory)), + &directory, + &subscription_callback_map, + &subscription_failure_callback_map, + WorkerID::FromBinary(address_.worker_id()), + client_factory)), rc_( - rpc::WorkerAddress(address_), publisher_.get(), subscriber_.get(), + rpc::WorkerAddress(address_), + publisher_.get(), + subscriber_.get(), [](const NodeID &node_id) { return true; }, - /*lineage_pinning_enabled=*/false, client_factory) {} + /*lineage_pinning_enabled=*/false, + client_factory) {} ~MockWorkerClient() override { if (!failed_) { @@ -294,15 +310,16 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { } } - void WaitForRefRemoved(const ObjectID object_id, const ObjectID contained_in_id, + void WaitForRefRemoved(const ObjectID object_id, + const ObjectID contained_in_id, rpc::Address owner_address) override { auto r = num_requests_; auto borrower_callback = [=]() { auto ref_removed_callback = absl::bind_front(&ReferenceCounter::HandleRefRemoved, &rc_); - rc_.SetRefRemovedCallback(object_id, contained_in_id, owner_address, - ref_removed_callback); + rc_.SetRefRemovedCallback( + object_id, contained_in_id, owner_address, ref_removed_callback); }; borrower_callbacks_[r] = borrower_callback; @@ -351,17 +368,24 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { } void PutWrappedId(const ObjectID outer_id, const ObjectID &inner_id) { - rc_.AddOwnedObject(outer_id, {inner_id}, address_, "", 0, false, + rc_.AddOwnedObject(outer_id, + {inner_id}, + address_, + "", + 0, + false, /*add_local_ref=*/true); } - void GetSerializedObjectId(const ObjectID outer_id, const ObjectID &inner_id, + void GetSerializedObjectId(const ObjectID outer_id, + const ObjectID &inner_id, const rpc::Address &owner_address) { rc_.AddLocalReference(inner_id, ""); rc_.AddBorrowedObject(inner_id, outer_id, owner_address); } - void ExecuteTaskWithArg(const ObjectID &arg_id, const ObjectID &inner_id, + void ExecuteTaskWithArg(const ObjectID &arg_id, + const ObjectID &inner_id, const rpc::Address &owner_address) { // Add a sentinel reference to keep the argument ID in scope even though // the frontend won't have a reference.
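Every hunk in this patch follows the same mechanical rule: clang-format keeps a call or declaration on one line if it fits; otherwise it breaks after the opening parenthesis and tries all arguments on a single continuation line; only when that still overflows the column limit does each parameter/argument get its own line. A hypothetical illustration (the names below are invented for this sketch, not taken from the diff), assuming the patch disables bin-packing in .clang-format via BinPackParameters: false and BinPackArguments: false:

    // Shape 1: the whole call fits, so it is left alone.
    ShortCall(a, b);
    // Shape 2: breaking after '(' lets all arguments share one continuation line.
    auto x = MediumLengthFunctionName(
        first_argument, second_argument, third_argument);
    // Shape 3: the arguments cannot share a line, so each gets its own,
    // aligned with the first argument.
    auto y = QuiteLongFunctionNameHere(first_long_argument_expression,
                                       second_long_argument_expression,
                                       third_long_argument_expression);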
@@ -380,7 +404,8 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { } ReferenceCounter::ReferenceTableProto FinishExecutingTask( - const ObjectID &arg_id, const ObjectID &return_id, + const ObjectID &arg_id, + const ObjectID &return_id, const ObjectID *return_wrapped_id = nullptr, const rpc::WorkerAddress *owner_address = nullptr) { if (return_wrapped_id) { @@ -395,7 +420,8 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { } void HandleSubmittedTaskFinished( - const ObjectID &return_id, const ObjectID &arg_id, + const ObjectID &return_id, + const ObjectID &arg_id, const std::unordered_map<ObjectID, std::vector<ObjectID>> &nested_return_ids = {}, const rpc::Address &borrower_address = empty_borrower, const ReferenceCounter::ReferenceTableProto &borrower_refs = empty_refs) { @@ -407,8 +433,8 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { if (!arg_id.IsNil()) { arguments.push_back(arg_id); } - rc_.UpdateFinishedTaskReferences({return_id}, arguments, false, borrower_address, - borrower_refs, nullptr); + rc_.UpdateFinishedTaskReferences( + {return_id}, arguments, false, borrower_address, borrower_refs, nullptr); } WorkerID GetID() const { return WorkerID::FromBinary(address_.worker_id()); } @@ -477,16 +503,16 @@ TEST_F(ReferenceCountTest, TestBasic) { ASSERT_TRUE(rc->IsObjectPendingCreation(return_id2)); ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - rc->UpdateFinishedTaskReferences({return_id1}, {id1}, false, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id1}, {id1}, false, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 4); ASSERT_EQ(out.size(), 0); - rc->UpdateFinishedTaskReferences({return_id2}, {id2}, false, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id2}, {id2}, false, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 3); ASSERT_EQ(out.size(), 1); - rc->UpdateFinishedTaskReferences({return_id2}, {id1}, false, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id2}, {id1}, false, empty_borrower, empty_refs, &out); ASSERT_EQ(out.size(), 2); ASSERT_FALSE(rc->IsObjectPendingCreation(return_id1)); ASSERT_FALSE(rc->IsObjectPendingCreation(return_id2)); @@ -503,12 +529,12 @@ TEST_F(ReferenceCountTest, TestBasic) { rc->RemoveLocalReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); ASSERT_EQ(out.size(), 0); - rc->UpdateFinishedTaskReferences({return_id1}, {id2}, false, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id1}, {id2}, false, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); ASSERT_EQ(out.size(), 0); - rc->UpdateFinishedTaskReferences({return_id1}, {id1}, false, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id1}, {id1}, false, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); ASSERT_EQ(out.size(), 1); rc->RemoveLocalReference(id2, &out); @@ -598,8 +624,14 @@ TEST_F(ReferenceCountTest, TestGetLocalityData) { // Owned object with defined object size and pinned node location should return valid // locality data.
int64_t object_size = 100; - rc->AddOwnedObject(obj1, {}, address, "file2.py:42", object_size, false, - /*add_local_ref=*/true, absl::optional<NodeID>(node1)); + rc->AddOwnedObject(obj1, + {}, + address, + "file2.py:42", + object_size, + false, + /*add_local_ref=*/true, + absl::optional<NodeID>(node1)); auto locality_data_obj1 = rc->GetLocalityData(obj1); ASSERT_TRUE(locality_data_obj1.has_value()); ASSERT_EQ(locality_data_obj1->object_size, object_size); @@ -660,8 +692,14 @@ TEST_F(ReferenceCountTest, TestGetLocalityData) { // Fetching locality data for an object that doesn't have an object size defined // should return a null optional. - rc->AddOwnedObject(obj2, {}, address, "file2.py:43", -1, false, - /*add_local_ref=*/true, absl::optional<NodeID>(node2)); + rc->AddOwnedObject(obj2, + {}, + address, + "file2.py:43", + -1, + false, + /*add_local_ref=*/true, + absl::optional<NodeID>(node2)); auto locality_data_obj2_no_object_size = rc->GetLocalityData(obj2); ASSERT_FALSE(locality_data_obj2_no_object_size.has_value()); @@ -710,8 +748,10 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { auto publisher = std::make_shared<pubsub::MockPublisher>(); auto subscriber = std::make_shared<pubsub::MockSubscriber>(); auto rc = std::shared_ptr<ReferenceCounter>( - new ReferenceCounter(rpc::WorkerAddress(rpc::Address()), publisher.get(), - subscriber.get(), [](const NodeID &node_id) { return true; })); + new ReferenceCounter(rpc::WorkerAddress(rpc::Address()), + publisher.get(), + subscriber.get(), + [](const NodeID &node_id) { return true; })); CoreWorkerMemoryStore store(rc); // Tests putting an object with no references is ignored. @@ -724,8 +764,12 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { ASSERT_EQ(store.Size(), 1); std::vector<std::shared_ptr<RayObject>> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx, - /*remove_after_get*/ true, &results)); + RAY_CHECK_OK(store.Get({id1}, + /*num_objects*/ 1, + /*timeout_ms*/ -1, + ctx, + /*remove_after_get*/ true, + &results)); ASSERT_EQ(results.size(), 1); ASSERT_EQ(store.Size(), 1); } @@ -778,8 +822,8 @@ TEST(DistributedReferenceCountTest, TestNoBorrow) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id1, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id1, outer_id, {}, borrower->address_, borrower_refs); borrower->FlushBorrowerCallbacks(); // Check that owner's ref count is now 0 for all objects. ASSERT_FALSE(owner->rc_.HasReference(inner_id)); @@ -835,8 +879,8 @@ TEST(DistributedReferenceCountTest, TestSimpleBorrower) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id1, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id1, outer_id, {}, borrower->address_, borrower_refs); borrower->FlushBorrowerCallbacks(); // Check that owner now has borrower in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -906,8 +950,8 @@ TEST(DistributedReferenceCountTest, TestSimpleBorrowerFailure) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own.
- owner->HandleSubmittedTaskFinished(return_id1, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id1, outer_id, {}, borrower->address_, borrower_refs); borrower->FlushBorrowerCallbacks(); // Check that owner now has borrower in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -965,8 +1009,8 @@ TEST(DistributedReferenceCountTest, TestSimpleBorrowerReferenceRemoved) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id, outer_id, {}, borrower->address_, borrower_refs); // Check that owner now has borrower in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); // Check that owner's ref count for outer == 0 since the borrower task @@ -1047,8 +1091,8 @@ TEST(DistributedReferenceCountTest, TestBorrowerTree) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id1, outer_id, {}, borrower1->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id1, outer_id, {}, borrower1->address_, borrower_refs); borrower1->FlushBorrowerCallbacks(); // Check that owner now has borrower in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -1066,8 +1110,8 @@ TEST(DistributedReferenceCountTest, TestBorrowerTree) { ASSERT_FALSE(borrower2->rc_.HasReference(outer_id2)); ASSERT_FALSE(borrower2->rc_.HasReference(outer_id)); - borrower1->HandleSubmittedTaskFinished(return_id2, outer_id2, {}, borrower2->address_, - borrower_refs); + borrower1->HandleSubmittedTaskFinished( + return_id2, outer_id2, {}, borrower2->address_, borrower_refs); borrower2->FlushBorrowerCallbacks(); // Borrower 1 no longer has a reference to any objects. ASSERT_FALSE(borrower1->rc_.HasReference(inner_id)); @@ -1137,8 +1181,8 @@ TEST(DistributedReferenceCountTest, TestNestedObjectNoBorrow) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id, outer_id, {}, borrower->address_, borrower_refs); // Check that owner now has nothing in scope. ASSERT_FALSE(owner->rc_.HasReference(outer_id)); ASSERT_FALSE(owner->rc_.HasReference(mid_id)); @@ -1201,8 +1245,8 @@ TEST(DistributedReferenceCountTest, TestNestedObject) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id, outer_id, {}, borrower->address_, borrower_refs); // Check that owner now has borrower in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); // Check that owner's ref count for outer and mid are 0 since the borrower @@ -1293,8 +1337,8 @@ TEST(DistributedReferenceCountTest, TestNestedObjectDifferentOwners) { // Borrower 1 should now know that borrower 2 is borrowing the inner object // ID. 
- borrower1->HandleSubmittedTaskFinished(return_id1, borrower_id, {}, borrower2->address_, - borrower_refs); + borrower1->HandleSubmittedTaskFinished( + return_id1, borrower_id, {}, borrower2->address_, borrower_refs); ASSERT_TRUE(borrower1->rc_.HasReference(owner_id1)); // Borrower 1 finishes. It should not have any references now because all @@ -1307,8 +1351,8 @@ TEST(DistributedReferenceCountTest, TestNestedObjectDifferentOwners) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id2, owner_id3, {}, borrower1->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id2, owner_id3, {}, borrower1->address_, borrower_refs); // Check that owner now has borrower2 in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(owner_id1)); ASSERT_FALSE(owner->rc_.HasReference(owner_id2)); @@ -1391,8 +1435,8 @@ TEST(DistributedReferenceCountTest, TestNestedObjectDifferentOwners2) { // Borrower 1 should now know that borrower 2 is borrowing the inner object // ID. - borrower1->HandleSubmittedTaskFinished(return_id1, borrower_id, {}, borrower2->address_, - borrower_refs); + borrower1->HandleSubmittedTaskFinished( + return_id1, borrower_id, {}, borrower2->address_, borrower_refs); ASSERT_TRUE(borrower1->rc_.HasReference(owner_id1)); ASSERT_TRUE(borrower1->rc_.HasReference(owner_id2)); @@ -1403,8 +1447,8 @@ TEST(DistributedReferenceCountTest, TestNestedObjectDifferentOwners2) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id2, owner_id3, {}, borrower1->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id2, owner_id3, {}, borrower1->address_, borrower_refs); // Check that owner now has borrower2 in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(owner_id1)); ASSERT_TRUE(owner->rc_.HasReference(owner_id2)); @@ -1481,8 +1525,8 @@ TEST(DistributedReferenceCountTest, TestBorrowerPingPong) { // The owner receives the borrower's reply and merges the borrower's ref // count into its own. - owner->HandleSubmittedTaskFinished(return_id1, outer_id, {}, borrower->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id1, outer_id, {}, borrower->address_, borrower_refs); borrower->FlushBorrowerCallbacks(); // Check that owner now has a borrower for inner. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -1498,8 +1542,8 @@ TEST(DistributedReferenceCountTest, TestBorrowerPingPong) { borrower_refs = owner->FinishExecutingTask(outer_id2, ObjectID::Nil()); ASSERT_TRUE(owner->rc_.HasReference(inner_id)); - borrower->HandleSubmittedTaskFinished(return_id2, outer_id2, {}, owner->address_, - borrower_refs); + borrower->HandleSubmittedTaskFinished( + return_id2, outer_id2, {}, owner->address_, borrower_refs); borrower->FlushBorrowerCallbacks(); // Borrower no longer has a reference to any objects. ASSERT_FALSE(borrower->rc_.HasReference(inner_id)); @@ -1566,10 +1610,10 @@ TEST(DistributedReferenceCountTest, TestDuplicateBorrower) { // The owner receives the borrower's replies and merges the borrower's ref // count into its own. 
- owner->HandleSubmittedTaskFinished(return_id1, outer_id, {}, borrower->address_, - borrower_refs1); - owner->HandleSubmittedTaskFinished(return_id3, outer_id, {}, borrower->address_, - borrower_refs2); + owner->HandleSubmittedTaskFinished( + return_id1, outer_id, {}, borrower->address_, borrower_refs1); + owner->HandleSubmittedTaskFinished( + return_id3, outer_id, {}, borrower->address_, borrower_refs2); borrower->FlushBorrowerCallbacks(); // Check that owner now has borrower in inner's borrowers list. ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -1635,8 +1679,8 @@ TEST(DistributedReferenceCountTest, TestForeignOwner) { ASSERT_FALSE(caller->rc_.HasReference(inner_id)); // Caller receives the owner's message, but inner_id is still in scope // because caller has a reference to return_id. - caller->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + caller->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); ASSERT_TRUE(caller->rc_.HasReference(inner_id)); // @@ -1664,7 +1708,12 @@ TEST(DistributedReferenceCountTest, TestForeignOwner) { // Phase 3 -- foreign owner gets ref removed information. // // Emulate ref removed callback. - foreign_owner->rc_.AddOwnedObject(inner_id, {}, foreign_owner->address_, "", 0, false, + foreign_owner->rc_.AddOwnedObject(inner_id, + {}, + foreign_owner->address_, + "", + 0, + false, /*add_local_ref=*/false); foreign_owner->rc_.AddBorrowerAddress(inner_id, owner->address_); @@ -1717,8 +1766,8 @@ TEST(DistributedReferenceCountTest, TestDuplicateNestedObject) { borrower2->rc_.RemoveLocalReference(owner_id2, nullptr); // The nested task returns while still using owner_id1. auto borrower_refs = borrower2->FinishExecutingTask(owner_id3, ObjectID::Nil()); - owner->HandleSubmittedTaskFinished(return_id1, owner_id3, {}, borrower2->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id1, owner_id3, {}, borrower2->address_, borrower_refs); ASSERT_TRUE(borrower2->FlushBorrowerCallbacks()); // The owner submits a task that is given a reference to owner_id1. @@ -1736,8 +1785,8 @@ TEST(DistributedReferenceCountTest, TestDuplicateNestedObject) { // It should now have 2 local references to owner_id1, one from the owner and // one from the borrower. borrower_refs = borrower2->FinishExecutingTask(borrower_id, ObjectID::Nil()); - borrower1->HandleSubmittedTaskFinished(return_id3, borrower_id, {}, borrower2->address_, - borrower_refs); + borrower1->HandleSubmittedTaskFinished( + return_id3, borrower_id, {}, borrower2->address_, borrower_refs); // Borrower 1 finishes. It should not have any references now because all // state has been merged into the owner. @@ -1748,8 +1797,8 @@ TEST(DistributedReferenceCountTest, TestDuplicateNestedObject) { ASSERT_FALSE(borrower1->rc_.HasReference(borrower_id)); // Borrower 1 should not have merge any refs into the owner because borrower 2's ref was // already merged into the owner. - owner->HandleSubmittedTaskFinished(return_id2, owner_id2, {}, borrower1->address_, - borrower_refs); + owner->HandleSubmittedTaskFinished( + return_id2, owner_id2, {}, borrower1->address_, borrower_refs); // The borrower receives the owner's wait message. borrower2->FlushBorrowerCallbacks(); @@ -1791,8 +1840,8 @@ TEST(DistributedReferenceCountTest, TestReturnObjectIdNoBorrow) { // Caller's ref to the task's return ID goes out of scope before it hears // from the owner of inner_id. 
- caller->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + caller->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); caller->rc_.RemoveLocalReference(return_id, nullptr); ASSERT_FALSE(caller->rc_.HasReference(return_id)); ASSERT_FALSE(caller->rc_.HasReference(inner_id)); @@ -1832,8 +1881,8 @@ TEST(DistributedReferenceCountTest, TestReturnObjectIdBorrow) { // Caller receives the owner's message, but inner_id is still in scope // because caller has a reference to return_id. - caller->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + caller->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); ASSERT_TRUE(caller->FlushBorrowerCallbacks()); ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -1881,8 +1930,8 @@ TEST(DistributedReferenceCountTest, TestReturnObjectIdBorrowChain) { // Caller receives the owner's message, but inner_id is still in scope // because caller has a reference to return_id. - caller->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + caller->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); auto return_id2 = caller->SubmitTaskWithArg(return_id); caller->rc_.RemoveLocalReference(return_id, nullptr); ASSERT_TRUE(caller->FlushBorrowerCallbacks()); @@ -1896,8 +1945,8 @@ TEST(DistributedReferenceCountTest, TestReturnObjectIdBorrowChain) { ASSERT_TRUE(borrower->rc_.HasReference(inner_id)); // Borrower merges ref count into the caller. - caller->HandleSubmittedTaskFinished(return_id2, return_id, {}, borrower->address_, - borrower_refs); + caller->HandleSubmittedTaskFinished( + return_id2, return_id, {}, borrower->address_, borrower_refs); // The caller should not have a ref count anymore because it was merged into // the owner. ASSERT_FALSE(caller->rc_.HasReference(return_id)); @@ -1953,8 +2002,8 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedId) { // Caller receives the owner's message, but inner_id is still in scope // because caller has a reference to return_id. - caller->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + caller->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); auto borrower_return_id = caller->SubmitTaskWithArg(return_id); caller->rc_.RemoveLocalReference(return_id, nullptr); ASSERT_TRUE(caller->FlushBorrowerCallbacks()); @@ -1969,9 +2018,11 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedId) { ASSERT_TRUE(borrower->rc_.HasReference(inner_id)); // Borrower merges ref count into the caller. - caller->HandleSubmittedTaskFinished(borrower_return_id, return_id, + caller->HandleSubmittedTaskFinished(borrower_return_id, + return_id, {{borrower_return_id, {inner_id}}}, - borrower->address_, borrower_refs); + borrower->address_, + borrower_refs); // The caller should still have a ref count because it has a reference to // borrower_return_id. ASSERT_FALSE(caller->rc_.HasReference(return_id)); @@ -2039,8 +2090,8 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedIdDeserialize) { // Caller receives the owner's message, but inner_id is still in scope // because caller has a reference to return_id. 
- caller->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + caller->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); auto borrower_return_id = caller->SubmitTaskWithArg(return_id); caller->rc_.RemoveLocalReference(return_id, nullptr); ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -2054,9 +2105,11 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedIdDeserialize) { ASSERT_TRUE(borrower->rc_.HasReference(inner_id)); // Borrower merges ref count into the caller. - caller->HandleSubmittedTaskFinished(borrower_return_id, return_id, + caller->HandleSubmittedTaskFinished(borrower_return_id, + return_id, {{borrower_return_id, {inner_id}}}, - borrower->address_, borrower_refs); + borrower->address_, + borrower_refs); // The caller should still have a ref count because it has a reference to // borrower_return_id. ASSERT_FALSE(caller->rc_.HasReference(return_id)); @@ -2123,16 +2176,16 @@ TEST(DistributedReferenceCountTest, TestReturnIdChain) { auto inner_id = ObjectID::FromRandom(); nested_worker->Put(inner_id); rpc::WorkerAddress worker_addr(worker->address_); - auto nested_refs = nested_worker->FinishExecutingTask(ObjectID::Nil(), nested_return_id, - &inner_id, &worker_addr); + auto nested_refs = nested_worker->FinishExecutingTask( + ObjectID::Nil(), nested_return_id, &inner_id, &worker_addr); nested_worker->rc_.RemoveLocalReference(inner_id, nullptr); ASSERT_TRUE(nested_worker->rc_.HasReference(inner_id)); // All task execution replies are received. - root->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {nested_return_id}}}); - worker->HandleSubmittedTaskFinished(nested_return_id, ObjectID::Nil(), - {{nested_return_id, {inner_id}}}); + root->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {nested_return_id}}}); + worker->HandleSubmittedTaskFinished( + nested_return_id, ObjectID::Nil(), {{nested_return_id, {inner_id}}}); root->FlushBorrowerCallbacks(); worker->FlushBorrowerCallbacks(); @@ -2185,14 +2238,14 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedIdChain) { auto inner_id = ObjectID::FromRandom(); nested_worker->Put(inner_id); rpc::WorkerAddress worker_addr(worker->address_); - auto nested_refs = nested_worker->FinishExecutingTask(ObjectID::Nil(), nested_return_id, - &inner_id, &worker_addr); + auto nested_refs = nested_worker->FinishExecutingTask( + ObjectID::Nil(), nested_return_id, &inner_id, &worker_addr); nested_worker->rc_.RemoveLocalReference(inner_id, nullptr); ASSERT_TRUE(nested_worker->rc_.HasReference(inner_id)); // Worker receives the reply from the nested task. - worker->HandleSubmittedTaskFinished(nested_return_id, ObjectID::Nil(), - {{nested_return_id, {inner_id}}}); + worker->HandleSubmittedTaskFinished( + nested_return_id, ObjectID::Nil(), {{nested_return_id, {inner_id}}}); worker->FlushBorrowerCallbacks(); // Worker deserializes the inner_id and returns it. worker->GetSerializedObjectId(nested_return_id, inner_id, nested_worker->address_); @@ -2208,8 +2261,8 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedIdChain) { // Root receives worker's reply, then the WaitForRefRemovedRequest from // nested_worker. - root->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + root->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); root->FlushBorrowerCallbacks(); // Object is still in scope because root now knows that return_id contains // inner_id. 
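The DistributedReferenceCountTest cases above and below all drive the same choreography through the MockWorkerClient helpers defined earlier in this file. Condensed, with the ref-release and assertion steps omitted (a sketch of the existing fixture usage, not a new test):

    // Owner creates inner/outer objects and submits a task taking outer_id;
    // the borrower executes it and reports its borrowed refs; the owner merges
    // them, and the borrower later answers WaitForRefRemoved via
    // FlushBorrowerCallbacks().
    auto borrower = std::make_shared<MockWorkerClient>("1");
    auto owner = std::make_shared<MockWorkerClient>(
        "2", [&](const rpc::Address &addr) { return borrower; });
    ObjectID inner_id = ObjectID::FromRandom();
    ObjectID outer_id = ObjectID::FromRandom();
    owner->Put(inner_id);
    owner->PutWrappedId(outer_id, inner_id);
    ObjectID return_id = owner->SubmitTaskWithArg(outer_id);
    borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->address_);
    auto borrower_refs = borrower->FinishExecutingTask(outer_id, return_id);
    owner->HandleSubmittedTaskFinished(
        return_id, outer_id, {}, borrower->address_, borrower_refs);
    borrower->FlushBorrowerCallbacks();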
@@ -2265,14 +2318,14 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedIdChainOutOfOrder) { auto inner_id = ObjectID::FromRandom(); nested_worker->Put(inner_id); rpc::WorkerAddress worker_addr(worker->address_); - auto nested_refs = nested_worker->FinishExecutingTask(ObjectID::Nil(), nested_return_id, - &inner_id, &worker_addr); + auto nested_refs = nested_worker->FinishExecutingTask( + ObjectID::Nil(), nested_return_id, &inner_id, &worker_addr); nested_worker->rc_.RemoveLocalReference(inner_id, nullptr); ASSERT_TRUE(nested_worker->rc_.HasReference(inner_id)); // Worker receives the reply from the nested task. - worker->HandleSubmittedTaskFinished(nested_return_id, ObjectID::Nil(), - {{nested_return_id, {inner_id}}}); + worker->HandleSubmittedTaskFinished( + nested_return_id, ObjectID::Nil(), {{nested_return_id, {inner_id}}}); worker->FlushBorrowerCallbacks(); // Worker deserializes the inner_id and returns it. worker->GetSerializedObjectId(nested_return_id, inner_id, nested_worker->address_); @@ -2291,8 +2344,8 @@ TEST(DistributedReferenceCountTest, TestReturnBorrowedIdChainOutOfOrder) { root->FlushBorrowerCallbacks(); ASSERT_TRUE(nested_worker->rc_.HasReference(inner_id)); - root->HandleSubmittedTaskFinished(return_id, ObjectID::Nil(), - {{return_id, {inner_id}}}); + root->HandleSubmittedTaskFinished( + return_id, ObjectID::Nil(), {{return_id, {inner_id}}}); root->rc_.RemoveLocalReference(return_id, nullptr); ASSERT_FALSE(root->rc_.HasReference(return_id)); ASSERT_FALSE(root->rc_.HasReference(inner_id)); @@ -2331,8 +2384,8 @@ TEST_F(ReferenceCountLineageEnabledTest, TestUnreconstructableObjectOutOfScope) rc->UpdateSubmittedTaskReferences({return_id}, {id}); ASSERT_TRUE(rc->IsObjectPendingCreation(return_id)); ASSERT_FALSE(*out_of_scope); - rc->UpdateFinishedTaskReferences({return_id}, {id}, false, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id}, {id}, false, empty_borrower, empty_refs, &out); ASSERT_FALSE(rc->IsObjectPendingCreation(return_id)); ASSERT_FALSE(*out_of_scope); @@ -2340,8 +2393,8 @@ TEST_F(ReferenceCountLineageEnabledTest, TestUnreconstructableObjectOutOfScope) // reaches 0. rc->UpdateResubmittedTaskReferences({return_id}, {id}); ASSERT_TRUE(rc->IsObjectPendingCreation(return_id)); - rc->UpdateFinishedTaskReferences({return_id}, {id}, true, empty_borrower, empty_refs, - &out); + rc->UpdateFinishedTaskReferences( + {return_id}, {id}, true, empty_borrower, empty_refs, &out); ASSERT_FALSE(rc->IsObjectPendingCreation(return_id)); ASSERT_TRUE(*out_of_scope); } @@ -2455,8 +2508,8 @@ TEST_F(ReferenceCountLineageEnabledTest, TestEvictLineage) { // ID1 depends on ID0. rc->UpdateSubmittedTaskReferences({ids[1]}, {ids[0]}); rc->RemoveLocalReference(ids[0], nullptr); - rc->UpdateFinishedTaskReferences({ids[1]}, {ids[0]}, /*release_lineage=*/false, - empty_borrower, empty_refs, nullptr); + rc->UpdateFinishedTaskReferences( + {ids[1]}, {ids[0]}, /*release_lineage=*/false, empty_borrower, empty_refs, nullptr); bool lineage_evicted = false; for (const auto &id : ids) { @@ -2646,10 +2699,20 @@ TEST_F(ReferenceCountTest, TestDelayedWaitForRefRemoved) { // Owner owns a nested object ref, borrower is using the outer ObjectRef. 
ObjectID outer_id = ObjectID::FromRandom(); ObjectID inner_id = ObjectID::FromRandom(); - owner->rc_.AddOwnedObject(outer_id, {}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(outer_id, + {}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); owner->rc_.AddBorrowerAddress(outer_id, borrower->address_); - owner->rc_.AddOwnedObject(inner_id, {}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(inner_id, + {}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/true); ASSERT_TRUE(owner->rc_.HasReference(outer_id)); ASSERT_TRUE(owner->rc_.HasReference(inner_id)); @@ -2689,11 +2752,26 @@ TEST_F(ReferenceCountTest, TestRepeatedDeserialization) { ObjectID outer_id = ObjectID::FromRandom(); ObjectID middle_id = ObjectID::FromRandom(); ObjectID inner_id = ObjectID::FromRandom(); - owner->rc_.AddOwnedObject(inner_id, {}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(inner_id, + {}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); - owner->rc_.AddOwnedObject(middle_id, {inner_id}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(middle_id, + {inner_id}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); - owner->rc_.AddOwnedObject(outer_id, {middle_id}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(outer_id, + {middle_id}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); owner->rc_.AddBorrowerAddress(outer_id, borrower->address_); ASSERT_TRUE(owner->rc_.HasReference(outer_id)); @@ -2740,11 +2818,26 @@ TEST_F(ReferenceCountTest, TestForwardNestedRefs) { ObjectID outer_id = ObjectID::FromRandom(); ObjectID middle_id = ObjectID::FromRandom(); ObjectID inner_id = ObjectID::FromRandom(); - owner->rc_.AddOwnedObject(inner_id, {}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(inner_id, + {}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); - owner->rc_.AddOwnedObject(middle_id, {inner_id}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(middle_id, + {inner_id}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); - owner->rc_.AddOwnedObject(outer_id, {middle_id}, owner->address_, "", 0, false, + owner->rc_.AddOwnedObject(outer_id, + {middle_id}, + owner->address_, + "", + 0, + false, /*add_local_ref=*/false); owner->rc_.AddBorrowerAddress(outer_id, borrower1->address_); ASSERT_TRUE(owner->rc_.HasReference(outer_id)); @@ -2761,8 +2854,8 @@ TEST_F(ReferenceCountTest, TestForwardNestedRefs) { borrower2->GetSerializedObjectId(middle_id, inner_id, owner->address_); borrower2->rc_.RemoveLocalReference(middle_id, nullptr); auto borrower_refs = borrower2->FinishExecutingTask(outer_id, ObjectID::Nil()); - borrower1->HandleSubmittedTaskFinished(return_id, outer_id, {}, borrower2->address_, - borrower_refs); + borrower1->HandleSubmittedTaskFinished( + return_id, outer_id, {}, borrower2->address_, borrower_refs); borrower1->rc_.RemoveLocalReference(outer_id, nullptr); // Now the owner should contact borrower 2. diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 0a9206eed..70acd2305 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -33,8 +33,10 @@ const int kMaxUnhandledErrorScanItems = 1000; /// A class that represents a `Get` request. 
class GetRequest { public: - GetRequest(absl::flat_hash_set<ObjectID> object_ids, size_t num_objects, - bool remove_after_get, bool abort_if_any_object_is_exception); + GetRequest(absl::flat_hash_set<ObjectID> object_ids, + size_t num_objects, + bool remove_after_get, + bool abort_if_any_object_is_exception); const absl::flat_hash_set<ObjectID> &ObjectIds() const; @@ -72,8 +74,10 @@ class GetRequest { std::condition_variable cv_; }; -GetRequest::GetRequest(absl::flat_hash_set<ObjectID> object_ids, size_t num_objects, - bool remove_after_get, bool abort_if_any_object_is_exception_) +GetRequest::GetRequest(absl::flat_hash_set<ObjectID> object_ids, + size_t num_objects, + bool remove_after_get, + bool abort_if_any_object_is_exception_) : object_ids_(std::move(object_ids)), num_objects_(num_objects), remove_after_get_(remove_after_get), @@ -148,9 +152,8 @@ CoreWorkerMemoryStore::CoreWorkerMemoryStore( std::shared_ptr<raylet::RayletClient> raylet_client, std::function<Status()> check_signals, std::function<void(const RayObject &)> unhandled_exception_handler, - std::function<std::shared_ptr<ray::RayObject>(const ray::RayObject &object, - const ObjectID &object_id)> - object_allocator) + std::function<std::shared_ptr<ray::RayObject>( + const ray::RayObject &object, const ObjectID &object_id)> object_allocator) : ref_counter_(std::move(counter)), raylet_client_(raylet_client), check_signals_(check_signals), @@ -200,8 +203,8 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ if (object_allocator_ != nullptr) { object_entry = object_allocator_(object, object_id); } else { - object_entry = std::make_shared<RayObject>(object.GetData(), object.GetMetadata(), - object.GetNestedRefs(), true); + object_entry = std::make_shared<RayObject>( + object.GetData(), object.GetMetadata(), object.GetNestedRefs(), true); } bool stored_in_direct_memory = true; @@ -262,16 +265,25 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ } Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids, - int num_objects, int64_t timeout_ms, - const WorkerContext &ctx, bool remove_after_get, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, + bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results) { - return GetImpl(object_ids, num_objects, timeout_ms, ctx, remove_after_get, results, + return GetImpl(object_ids, + num_objects, + timeout_ms, + ctx, + remove_after_get, + results, /*abort_if_any_object_is_exception=*/true); } Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids, - int num_objects, int64_t timeout_ms, - const WorkerContext &ctx, bool remove_after_get, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, + bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results, bool abort_if_any_object_is_exception) { (*results).resize(object_ids.size(), nullptr); @@ -318,9 +330,10 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids, size_t required_objects = num_objects - (object_ids.size() - remaining_ids.size()); // Otherwise, create a GetRequest to track remaining objects.
- get_request = - std::make_shared<GetRequest>(std::move(remaining_ids), required_objects, - remove_after_get, abort_if_any_object_is_exception); + get_request = std::make_shared<GetRequest>(std::move(remaining_ids), + required_objects, + remove_after_get, + abort_if_any_object_is_exception); for (const auto &object_id : get_request->ObjectIds()) { object_get_requests_[object_id].push_back(get_request); } @@ -404,14 +417,19 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids, } Status CoreWorkerMemoryStore::Get( - const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms, + const absl::flat_hash_set<ObjectID> &object_ids, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results, bool *got_exception) { const std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); std::vector<std::shared_ptr<RayObject>> result_objects; - RAY_RETURN_NOT_OK(Get(id_vector, id_vector.size(), timeout_ms, ctx, - /*remove_after_get=*/false, &result_objects)); + RAY_RETURN_NOT_OK(Get(id_vector, + id_vector.size(), + timeout_ms, + ctx, + /*remove_after_get=*/false, + &result_objects)); for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { @@ -428,13 +446,19 @@ Status CoreWorkerMemoryStore::Get( } Status CoreWorkerMemoryStore::Wait(const absl::flat_hash_set<ObjectID> &object_ids, - int num_objects, int64_t timeout_ms, + int num_objects, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set<ObjectID> *ready) { std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); std::vector<std::shared_ptr<RayObject>> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); - auto status = GetImpl(id_vector, num_objects, timeout_ms, ctx, false, &result_objects, + auto status = GetImpl(id_vector, + num_objects, + timeout_ms, + ctx, + false, + &result_objects, /*abort_if_any_object_is_exception=*/false); // Ignore TimedOut statuses since we return ready objects explicitly. if (!status.IsTimedOut()) { diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index a8d233b4f..9164c63bd 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -75,19 +75,25 @@ class CoreWorkerMemoryStore { /// finishes. This has no effect if ref counting is enabled. /// \param[out] results Result list of objects data. /// \return Status. - Status Get(const std::vector<ObjectID> &object_ids, int num_objects, int64_t timeout_ms, - const WorkerContext &ctx, bool remove_after_get, + Status Get(const std::vector<ObjectID> &object_ids, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, + bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results); /// Convenience wrapper around Get() that stores results in a given result map. - Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms, + Status Get(const absl::flat_hash_set<ObjectID> &object_ids, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results, bool *got_exception); /// Convenience wrapper around Get() that stores ready objects in a given result set. - Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, - int64_t timeout_ms, const WorkerContext &ctx, + Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, absl::flat_hash_set<ObjectID> *ready); /// Get an object if it exists. @@ -163,8 +169,11 @@ class CoreWorkerMemoryStore { /// See the public version of `Get` for meaning of the other arguments.
/// \param[in] abort_if_any_object_is_exception Whether we should abort if any object /// resources. is an exception. - Status GetImpl(const std::vector<ObjectID> &object_ids, int num_objects, - int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, + Status GetImpl(const std::vector<ObjectID> &object_ids, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, + bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results, bool abort_if_any_object_is_exception); diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 2d2c8707a..02a46f983 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -22,7 +22,8 @@ namespace ray { namespace core { -void BufferTracker::Record(const ObjectID &object_id, TrackedBuffer *buffer, +void BufferTracker::Record(const ObjectID &object_id, + TrackedBuffer *buffer, const std::string &call_site) { absl::MutexLock lock(&active_buffers_mutex_); active_buffers_[std::make_pair(object_id, buffer)] = call_site; @@ -56,7 +57,8 @@ CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr<raylet::RayletClient> raylet_client, const std::shared_ptr<ReferenceCounter> reference_counter, - std::function<Status()> check_signals, bool warmup, + std::function<Status()> check_signals, + bool warmup, std::function<std::string()> get_current_call_site) : raylet_client_(raylet_client), reference_counter_(reference_counter), @@ -85,8 +87,11 @@ Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object, RAY_CHECK(!object.IsInPlasmaError()) << object_id; std::shared_ptr<Buffer> data; RAY_RETURN_NOT_OK(Create(object.GetMetadata(), - object.HasData() ? object.GetData()->Size() : 0, object_id, - owner_address, &data, /*created_by_worker=*/true)); + object.HasData() ? object.GetData()->Size() : 0, + object_id, + owner_address, + &data, + /*created_by_worker=*/true)); // data could be a nullptr if the ObjectID already existed, but this does // not throw an error. if (data != nullptr) { @@ -113,10 +118,15 @@ Status CoreWorkerPlasmaStoreProvider::Create(const std::shared_ptr<Buffer> &meta if (!created_by_worker) { source = plasma::flatbuf::ObjectSource::RestoredFromStorage; } - Status status = store_client_.CreateAndSpillIfNeeded( - object_id, owner_address, data_size, metadata ? metadata->Data() : nullptr, - metadata ? metadata->Size() : 0, data, source, - /*device_num=*/0); + Status status = + store_client_.CreateAndSpillIfNeeded(object_id, + owner_address, + data_size, + metadata ? metadata->Data() : nullptr, + metadata ?
metadata->Size() : 0, + data, + source, + /*device_num=*/0); if (status.IsObjectStoreFull()) { RAY_LOG(ERROR) << "Failed to put object " << object_id @@ -152,17 +162,26 @@ Status CoreWorkerPlasmaStoreProvider::Release(const ObjectID &object_id) { } Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( - absl::flat_hash_set<ObjectID> &remaining, const std::vector<ObjectID> &batch_ids, - int64_t timeout_ms, bool fetch_only, bool in_direct_call, const TaskID &task_id, + absl::flat_hash_set<ObjectID> &remaining, + const std::vector<ObjectID> &batch_ids, + int64_t timeout_ms, + bool fetch_only, + bool in_direct_call, + const TaskID &task_id, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results, bool *got_exception) { const auto owner_addresses = reference_counter_->GetOwnerAddresses(batch_ids); - RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct( - batch_ids, owner_addresses, fetch_only, /*mark_worker_blocked*/ !in_direct_call, - task_id)); + RAY_RETURN_NOT_OK( + raylet_client_->FetchOrReconstruct(batch_ids, + owner_addresses, + fetch_only, + /*mark_worker_blocked*/ !in_direct_call, + task_id)); std::vector<plasma::ObjectBuffer> plasma_results; - RAY_RETURN_NOT_OK(store_client_.Get(batch_ids, timeout_ms, &plasma_results, + RAY_RETURN_NOT_OK(store_client_.Get(batch_ids, + timeout_ms, + &plasma_results, /*is_from_worker=*/true)); // Add successfully retrieved objects to the result map and remove them from @@ -175,8 +194,8 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( if (plasma_results[i].data && plasma_results[i].data->Size()) { // We track the set of active data buffers in active_buffers_. On destruction, // the buffer entry will be removed from the set via callback. - data = std::make_shared<TrackedBuffer>(plasma_results[i].data, buffer_tracker_, - object_id); + data = std::make_shared<TrackedBuffer>( + plasma_results[i].data, buffer_tracker_, object_id); buffer_tracker_->Record(object_id, data.get(), get_current_call_site_()); } if (plasma_results[i].metadata && plasma_results[i].metadata->Size()) { @@ -201,7 +220,9 @@ Status CoreWorkerPlasmaStoreProvider::GetIfLocal( absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results) { std::vector<plasma::ObjectBuffer> plasma_results; // Since this path is used only for spilling, we should set is_from_worker: false. - RAY_RETURN_NOT_OK(store_client_.Get(object_ids, /*timeout_ms=*/0, &plasma_results, + RAY_RETURN_NOT_OK(store_client_.Get(object_ids, + /*timeout_ms=*/0, + &plasma_results, /*is_from_worker=*/false)); for (size_t i = 0; i < object_ids.size(); i++) { @@ -212,8 +233,8 @@ Status CoreWorkerPlasmaStoreProvider::GetIfLocal( if (plasma_results[i].data && plasma_results[i].data->Size()) { // We track the set of active data buffers in active_buffers_. On destruction, // the buffer entry will be removed from the set via callback.
- data = std::make_shared<TrackedBuffer>(plasma_results[i].data, buffer_tracker_, - object_id); + data = std::make_shared<TrackedBuffer>( + plasma_results[i].data, buffer_tracker_, object_id); buffer_tracker_->Record(object_id, data.get(), get_current_call_site_()); } if (plasma_results[i].metadata && plasma_results[i].metadata->Size()) { @@ -243,7 +264,8 @@ Status UnblockIfNeeded(const std::shared_ptr<raylet::RayletClient> &client, } Status CoreWorkerPlasmaStoreProvider::Get( - const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms, + const absl::flat_hash_set<ObjectID> &object_ids, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results, bool *got_exception) { @@ -259,10 +281,14 @@ Status CoreWorkerPlasmaStoreProvider::Get( for (int64_t i = start; i < batch_size && i < total_size; i++) { batch_ids.push_back(id_vector[start + i]); } - RAY_RETURN_NOT_OK( - FetchAndGetFromPlasmaStore(remaining, batch_ids, /*timeout_ms=*/0, - /*fetch_only=*/true, ctx.CurrentTaskIsDirectCall(), - ctx.GetCurrentTaskID(), results, got_exception)); + RAY_RETURN_NOT_OK(FetchAndGetFromPlasmaStore(remaining, + batch_ids, + /*timeout_ms=*/0, + /*fetch_only=*/true, + ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), + results, + got_exception)); } // If all objects were fetched already, return. Note that we always need to @@ -301,10 +327,14 @@ Status CoreWorkerPlasmaStoreProvider::Get( RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked( /*release_resources_during_plasma_fetch=*/false)); } - RAY_RETURN_NOT_OK( - FetchAndGetFromPlasmaStore(remaining, batch_ids, batch_timeout, - /*fetch_only=*/false, ctx.CurrentTaskIsDirectCall(), - ctx.GetCurrentTaskID(), results, got_exception)); + RAY_RETURN_NOT_OK(FetchAndGetFromPlasmaStore(remaining, + batch_ids, + batch_timeout, + /*fetch_only=*/false, + ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), + results, + got_exception)); should_break = timed_out || *got_exception; if ((previous_size - remaining.size()) < batch_ids.size()) { @@ -344,8 +374,11 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, } Status CoreWorkerPlasmaStoreProvider::Wait( - const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, int64_t timeout_ms, - const WorkerContext &ctx, absl::flat_hash_set<ObjectID> *ready) { + const absl::flat_hash_set<ObjectID> &object_ids, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, + absl::flat_hash_set<ObjectID> *ready) { std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); bool should_break = false; @@ -366,9 +399,13 @@ Status CoreWorkerPlasmaStoreProvider::Wait( } const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector); RAY_RETURN_NOT_OK( - raylet_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout, + raylet_client_->Wait(id_vector, + owner_addresses, + num_objects, + call_timeout, /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), - ctx.GetCurrentTaskID(), &result_pair)); + ctx.GetCurrentTaskID(), + &result_pair)); if (result_pair.first.size() >= static_cast<size_t>(num_objects)) { should_break = true; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index e1078a5d8..4bcfad197 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -35,7 +35,8 @@ class TrackedBuffer; class BufferTracker { public: // Track an object.
- void Record(const ObjectID &object_id, TrackedBuffer *buffer, + void Record(const ObjectID &object_id, + TrackedBuffer *buffer, const std::string &call_site); // Release an object from tracking. void Release(const ObjectID &object_id, TrackedBuffer *buffer); @@ -58,7 +59,8 @@ class BufferTracker { class TrackedBuffer : public Buffer { public: TrackedBuffer(std::shared_ptr<Buffer> buffer, - const std::shared_ptr<BufferTracker> &tracker, const ObjectID &object_id) + const std::shared_ptr<BufferTracker> &tracker, + const ObjectID &object_id) : buffer_(buffer), tracker_(tracker), object_id_(object_id) {} uint8_t *Data() const override { return buffer_->Data(); } @@ -89,7 +91,8 @@ class CoreWorkerPlasmaStoreProvider { const std::string &store_socket, const std::shared_ptr<raylet::RayletClient> raylet_client, const std::shared_ptr<ReferenceCounter> reference_counter, - std::function<Status()> check_signals, bool warmup, + std::function<Status()> check_signals, + bool warmup, std::function<std::string()> get_current_call_site = nullptr); ~CoreWorkerPlasmaStoreProvider(); @@ -105,8 +108,10 @@ class CoreWorkerPlasmaStoreProvider { /// \param[out] object_exists Optional. Returns whether an object with the /// same ID already exists. If this is true, then the Put does not write any /// object data. - Status Put(const RayObject &object, const ObjectID &object_id, - const rpc::Address &owner_address, bool *object_exists); + Status Put(const RayObject &object, + const ObjectID &object_id, + const rpc::Address &owner_address, + bool *object_exists); /// Create an object in plasma and return a mutable buffer to it. The buffer should be /// subsequently written to and then sealed using Seal(). @@ -116,9 +121,12 @@ class CoreWorkerPlasmaStoreProvider { /// \param[in] object_id The ID of the object. /// \param[in] owner_address The address of the object's owner. /// \param[out] data The mutable object buffer in plasma that can be written to. - Status Create(const std::shared_ptr<Buffer> &metadata, const size_t data_size, - const ObjectID &object_id, const rpc::Address &owner_address, - std::shared_ptr<Buffer> *data, bool created_by_worker); + Status Create(const std::shared_ptr<Buffer> &metadata, + const size_t data_size, + const ObjectID &object_id, + const rpc::Address &owner_address, + std::shared_ptr<Buffer> *data, + bool created_by_worker); /// Seal an object buffer created with Create(). /// @@ -137,7 +145,8 @@ class CoreWorkerPlasmaStoreProvider { /// argument to Get to retrieve the object data. Status Release(const ObjectID &object_id); - Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms, + Status Get(const absl::flat_hash_set<ObjectID> &object_ids, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results, bool *got_exception); @@ -156,8 +165,10 @@ class CoreWorkerPlasmaStoreProvider { Status Contains(const ObjectID &object_id, bool *has_object); - Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, - int64_t timeout_ms, const WorkerContext &ctx, + Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, + int num_objects, + int64_t timeout_ms, + const WorkerContext &ctx, absl::flat_hash_set<ObjectID> *ready); Status Delete(const absl::flat_hash_set<ObjectID> &object_ids, bool local_only); @@ -188,8 +199,11 @@ class CoreWorkerPlasmaStoreProvider { /// exception. /// \return Status.
   Status FetchAndGetFromPlasmaStore(
-      absl::flat_hash_set<ObjectID> &remaining, const std::vector<ObjectID> &batch_ids,
-      int64_t timeout_ms, bool fetch_only, bool in_direct_call_task,
+      absl::flat_hash_set<ObjectID> &remaining,
+      const std::vector<ObjectID> &batch_ids,
+      int64_t timeout_ms,
+      bool fetch_only,
+      bool in_direct_call_task,
       const TaskID &task_id,
       absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
       bool *got_exception);
diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc
index 2b4a37070..3a52a93fe 100644
--- a/src/ray/core_worker/task_manager.cc
+++ b/src/ray/core_worker/task_manager.cc
@@ -29,8 +29,10 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
 const int64_t kTaskFailureLoggingFrequencyMillis = 5000;
 
 std::vector<rpc::ObjectReference> TaskManager::AddPendingTask(
-    const rpc::Address &caller_address, const TaskSpecification &spec,
-    const std::string &call_site, int max_retries) {
+    const rpc::Address &caller_address,
+    const TaskSpecification &spec,
+    const std::string &call_site,
+    int max_retries) {
   RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId() << " with " << max_retries
                  << " retries";
 
@@ -75,7 +77,10 @@ std::vector<rpc::ObjectReference> TaskManager::AddPendingTask(
     // language frontend. Note that the language bindings should set
     // skip_adding_local_ref=True to avoid double referencing the object.
     reference_counter_->AddOwnedObject(return_id,
-                                       /*inner_ids=*/{}, caller_address, call_site, -1,
+                                       /*inner_ids=*/{},
+                                       caller_address,
+                                       call_site,
+                                       -1,
                                        /*is_reconstructable=*/is_reconstructable,
                                        /*add_local_ref=*/true);
   }
@@ -378,7 +383,8 @@ bool TaskManager::RetryTaskIfPossible(const TaskID &task_id) {
   }
 }
 
-void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type,
+void TaskManager::FailPendingTask(const TaskID &task_id,
+                                  rpc::ErrorType error_type,
                                   const Status *status,
                                   const rpc::RayErrorInfo *ray_error_info,
                                   bool mark_task_object_failed) {
@@ -419,7 +425,9 @@ void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_ty
     // The worker failed to execute the task, so it cannot be borrowing any
     // objects.
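     // Hence the empty borrower address and empty borrowed-refs table in the
     // call below.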
-    RemoveFinishedTaskReferences(spec, /*release_lineage=*/true, rpc::Address(),
+    RemoveFinishedTaskReferences(spec,
+                                 /*release_lineage=*/true,
+                                 rpc::Address(),
                                  ReferenceCounter::ReferenceTableProto());
     if (mark_task_object_failed) {
       MarkTaskReturnObjectsFailed(spec, error_type, ray_error_info);
@@ -428,7 +436,8 @@ void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_ty
   ShutdownIfNeeded();
 }
 
-bool TaskManager::FailOrRetryPendingTask(const TaskID &task_id, rpc::ErrorType error_type,
+bool TaskManager::FailOrRetryPendingTask(const TaskID &task_id,
+                                         rpc::ErrorType error_type,
                                          const Status *status,
                                          const rpc::RayErrorInfo *ray_error_info,
                                          bool mark_task_object_failed) {
@@ -469,12 +478,15 @@ void TaskManager::OnTaskDependenciesInlined(
   reference_counter_->UpdateSubmittedTaskReferences(
       /*return_ids=*/{},
       /*argument_ids_to_add=*/contained_ids,
-      /*argument_ids_to_remove=*/inlined_dependency_ids, &deleted);
+      /*argument_ids_to_remove=*/inlined_dependency_ids,
+      &deleted);
   in_memory_store_->Delete(deleted);
 }
 
 void TaskManager::RemoveFinishedTaskReferences(
-    TaskSpecification &spec, bool release_lineage, const rpc::Address &borrower_addr,
+    TaskSpecification &spec,
+    bool release_lineage,
+    const rpc::Address &borrower_addr,
     const ReferenceCounter::ReferenceTableProto &borrowed_refs) {
   std::vector<ObjectID> plasma_dependencies;
   for (size_t i = 0; i < spec.NumArgs(); i++) {
@@ -502,9 +514,12 @@ void TaskManager::RemoveFinishedTaskReferences(
   }
 
   std::vector<ObjectID> deleted;
-  reference_counter_->UpdateFinishedTaskReferences(return_ids, plasma_dependencies,
-                                                   release_lineage, borrower_addr,
-                                                   borrowed_refs, &deleted);
+  reference_counter_->UpdateFinishedTaskReferences(return_ids,
+                                                   plasma_dependencies,
+                                                   release_lineage,
+                                                   borrower_addr,
+                                                   borrowed_refs,
+                                                   &deleted);
   in_memory_store_->Delete(deleted);
 }
 
diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h
index b22f4dc7e..c67f4b370 100644
--- a/src/ray/core_worker/task_manager.h
+++ b/src/ray/core_worker/task_manager.h
@@ -29,17 +29,20 @@ namespace core {
 
 class TaskFinisherInterface {
  public:
-  virtual void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply,
+  virtual void CompletePendingTask(const TaskID &task_id,
+                                   const rpc::PushTaskReply &reply,
                                    const rpc::Address &actor_addr) = 0;
 
   virtual bool RetryTaskIfPossible(const TaskID &task_id) = 0;
 
-  virtual void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type,
+  virtual void FailPendingTask(const TaskID &task_id,
+                               rpc::ErrorType error_type,
                                const Status *status = nullptr,
                                const rpc::RayErrorInfo *ray_error_info = nullptr,
                                bool mark_task_object_failed = true) = 0;
 
-  virtual bool FailOrRetryPendingTask(const TaskID &task_id, rpc::ErrorType error_type,
+  virtual bool FailOrRetryPendingTask(const TaskID &task_id,
+                                      rpc::ErrorType error_type,
                                       const Status *status,
                                       const rpc::RayErrorInfo *ray_error_info = nullptr,
                                       bool mark_task_object_failed = true) = 0;
@@ -53,7 +56,8 @@ class TaskFinisherInterface {
   virtual bool MarkTaskCanceled(const TaskID &task_id) = 0;
 
   virtual void MarkTaskReturnObjectsFailed(
-      const TaskSpecification &spec, rpc::ErrorType error_type,
+      const TaskSpecification &spec,
+      rpc::ErrorType error_type,
       const rpc::RayErrorInfo *ray_error_info = nullptr) = 0;
 
   virtual absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const = 0;
@@ -72,9 +76,10 @@ using PutInLocalPlasmaCallback =
     std::function<void(const RayObject &object, const ObjectID &object_id)>;
 using RetryTaskCallback = std::function<void(TaskSpecification &spec, bool delay)>;
 using ReconstructObjectCallback = std::function<void(const ObjectID &object_id)>;
-using PushErrorCallback =
-    std::function<Status(const JobID &job_id, const std::string &type,
-                         const std::string &error_message, double timestamp)>;
+using PushErrorCallback = std::function<Status(const JobID &job_id,
+                                               const std::string &type,
+                                               const std::string &error_message,
+                                               double timestamp)>;
 
 class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface {
  public:
@@ -82,7 +87,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
               std::shared_ptr<ReferenceCounter> reference_counter,
               PutInLocalPlasmaCallback put_in_local_plasma_callback,
               RetryTaskCallback retry_task_callback,
-              PushErrorCallback push_error_callback, int64_t max_lineage_bytes)
+              PushErrorCallback push_error_callback,
+              int64_t max_lineage_bytes)
       : in_memory_store_(in_memory_store),
         reference_counter_(reference_counter),
         put_in_local_plasma_callback_(put_in_local_plasma_callback),
@@ -137,7 +143,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   /// \param[in] reply Proto response to a direct actor or task call.
   /// \param[in] worker_addr Address of the worker that executed the task.
   /// \return Void.
-  void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply,
+  void CompletePendingTask(const TaskID &task_id,
+                           const rpc::PushTaskReply &reply,
                            const rpc::Address &worker_addr) override;
 
   bool RetryTaskIfPossible(const TaskID &task_id) override;
@@ -154,7 +161,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   /// \param[in] mark_task_object_failed whether or not it marks the task
   /// return object as failed.
   /// \return Whether the task will be retried or not.
-  bool FailOrRetryPendingTask(const TaskID &task_id, rpc::ErrorType error_type,
+  bool FailOrRetryPendingTask(const TaskID &task_id,
+                              rpc::ErrorType error_type,
                               const Status *status = nullptr,
                               const rpc::RayErrorInfo *ray_error_info = nullptr,
                               bool mark_task_object_failed = true) override;
@@ -169,7 +177,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   /// \param[in] ray_error_info The error information of a given error type.
   /// \param[in] mark_task_object_failed whether or not it marks the task
   /// return object as failed.
-  void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type,
+  void FailPendingTask(const TaskID &task_id,
+                       rpc::ErrorType error_type,
                        const Status *status = nullptr,
                        const rpc::RayErrorInfo *ray_error_info = nullptr,
                        bool mark_task_object_failed = true) override;
@@ -181,7 +190,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   /// \param[in] error_type The error type the returned Ray object will store.
   /// \param[in] ray_error_info The error information of a given error type.
   void MarkTaskReturnObjectsFailed(
-      const TaskSpecification &spec, rpc::ErrorType error_type,
+      const TaskSpecification &spec,
+      rpc::ErrorType error_type,
       const rpc::RayErrorInfo *ray_error_info = nullptr) override LOCKS_EXCLUDED(mu_);
 
   /// A task's dependencies were inlined in the task spec. This will decrement
@@ -248,7 +258,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
  private:
   struct TaskEntry {
-    TaskEntry(const TaskSpecification &spec_arg, int num_retries_left_arg,
+    TaskEntry(const TaskSpecification &spec_arg,
+              int num_retries_left_arg,
               size_t num_returns)
         : spec(spec_arg), num_retries_left(num_retries_left_arg) {
       for (size_t i = 0; i < num_returns; i++) {
@@ -312,7 +323,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   /// failed. The remaining dependencies are plasma objects and any ObjectIDs
   /// that were inlined in the task spec.
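   /// The borrowed_refs table reported by the executing worker is merged into
   /// the reference counter as part of this release.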
   void RemoveFinishedTaskReferences(
-      TaskSpecification &spec, bool release_lineage, const rpc::Address &worker_addr,
+      TaskSpecification &spec,
+      bool release_lineage,
+      const rpc::Address &worker_addr,
       const ReferenceCounter::ReferenceTableProto &borrowed_refs);
 
   /// Shutdown if all tasks are finished and shutdown is scheduled.
diff --git a/src/ray/core_worker/test/actor_creator_test.cc b/src/ray/core_worker/test/actor_creator_test.cc
index bedad5692..06cb76d93 100644
--- a/src/ray/core_worker/test/actor_creator_test.cc
+++ b/src/ray/core_worker/test/actor_creator_test.cc
@@ -86,7 +86,8 @@ int main(int argc, char **argv) {
   ::testing::InitGoogleTest(&argc, argv);
   InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
-                                         ray::RayLog::ShutDownRayLog, argv[0],
+                                         ray::RayLog::ShutDownRayLog,
+                                         argv[0],
                                          ray::RayLogLevel::INFO,
                                          /*log_dir=*/"");
   ray::RayLog::InstallFailureSignalHandler(argv[0]);
diff --git a/src/ray/core_worker/test/actor_manager_test.cc b/src/ray/core_worker/test/actor_manager_test.cc
index 189783292..c355e5fc8 100644
--- a/src/ray/core_worker/test/actor_manager_test.cc
+++ b/src/ray/core_worker/test/actor_manager_test.cc
@@ -72,17 +72,24 @@ class MockGcsClient : public gcs::GcsClient {
 class MockDirectActorSubmitter : public CoreWorkerDirectActorTaskSubmitterInterface {
  public:
   MockDirectActorSubmitter() : CoreWorkerDirectActorTaskSubmitterInterface() {}
 
-  void AddActorQueueIfNotExists(const ActorID &actor_id, int32_t max_pending_calls,
+  void AddActorQueueIfNotExists(const ActorID &actor_id,
+                                int32_t max_pending_calls,
                                 bool execute_out_of_order = false) override {
     AddActorQueueIfNotExists_(actor_id, max_pending_calls, execute_out_of_order);
   }
 
   MOCK_METHOD3(AddActorQueueIfNotExists_,
-               void(const ActorID &actor_id, int32_t max_pending_calls,
+               void(const ActorID &actor_id,
+                    int32_t max_pending_calls,
                     bool execute_out_of_order));
-  MOCK_METHOD3(ConnectActor, void(const ActorID &actor_id, const rpc::Address &address,
-                                  int64_t num_restarts));
+  MOCK_METHOD3(ConnectActor,
+               void(const ActorID &actor_id,
+                    const rpc::Address &address,
+                    int64_t num_restarts));
-  MOCK_METHOD4(DisconnectActor, void(const ActorID &actor_id, int64_t num_restarts,
-                                     bool dead, const rpc::ActorDeathCause &death_cause));
+  MOCK_METHOD4(DisconnectActor,
+               void(const ActorID &actor_id,
+                    int64_t num_restarts,
+                    bool dead,
+                    const rpc::ActorDeathCause &death_cause));
   MOCK_METHOD3(KillActor,
                void(const ActorID &actor_id, bool force_kill, bool no_restart));
 
@@ -99,14 +106,18 @@ class MockReferenceCounter : public ReferenceCounterInterface {
                void(const ObjectID &object_id, const std::string &call_sit));
 
   MOCK_METHOD4(AddBorrowedObject,
-               bool(const ObjectID &object_id, const ObjectID &outer_id,
+               bool(const ObjectID &object_id,
+                    const ObjectID &outer_id,
                     const rpc::Address &owner_address,
                     bool foreign_owner_already_monitoring));
 
   MOCK_METHOD8(AddOwnedObject,
-               void(const ObjectID &object_id, const std::vector<ObjectID> &contained_ids,
-                    const rpc::Address &owner_address, const std::string &call_site,
-                    const int64_t object_size, bool is_reconstructable,
+               void(const ObjectID &object_id,
+                    const std::vector<ObjectID> &contained_ids,
+                    const rpc::Address &owner_address,
+                    const std::string &call_site,
+                    const int64_t object_size,
+                    bool is_reconstructable,
                     bool add_local_ref,
                     const absl::optional<NodeID> &pinned_at_raylet_id));
 
@@ -146,13 +157,24 @@ class ActorManagerTest : public ::testing::Test {
     RayFunction function(Language::PYTHON,
                          FunctionDescriptorBuilder::BuildPython("", "", "", ""));
-    auto actor_handle = absl::make_unique<ActorHandle>(
-        actor_id, TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(),
-        function.GetLanguage(), function.GetFunctionDescriptor(), "", 0, "", "", -1,
-        false);
+    auto actor_handle = absl::make_unique<ActorHandle>(actor_id,
+                                                       TaskID::Nil(),
+                                                       rpc::Address(),
+                                                       job_id,
+                                                       ObjectID::FromRandom(),
+                                                       function.GetLanguage(),
+                                                       function.GetFunctionDescriptor(),
+                                                       "",
+                                                       0,
+                                                       "",
+                                                       "",
+                                                       -1,
+                                                       false);
     EXPECT_CALL(*reference_counter_, SetDeleteCallback(_, _))
         .WillRepeatedly(testing::Return(true));
-    actor_manager_->AddNewActorHandle(move(actor_handle), call_site, caller_address,
+    actor_manager_->AddNewActorHandle(move(actor_handle),
+                                      call_site,
+                                      caller_address,
                                       /*is_detached*/ false);
     return actor_id;
   }
@@ -173,25 +195,45 @@ TEST_F(ActorManagerTest, TestAddAndGetActorHandleEndToEnd) {
   const auto call_site = "";
   RayFunction function(Language::PYTHON,
                        FunctionDescriptorBuilder::BuildPython("", "", "", ""));
-  auto actor_handle = absl::make_unique<ActorHandle>(
-      actor_id, TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(),
-      function.GetLanguage(), function.GetFunctionDescriptor(), "", 0, "", "", -1, false);
+  auto actor_handle = absl::make_unique<ActorHandle>(actor_id,
+                                                     TaskID::Nil(),
+                                                     rpc::Address(),
+                                                     job_id,
+                                                     ObjectID::FromRandom(),
+                                                     function.GetLanguage(),
+                                                     function.GetFunctionDescriptor(),
+                                                     "",
+                                                     0,
+                                                     "",
+                                                     "",
+                                                     -1,
+                                                     false);
   EXPECT_CALL(*reference_counter_, SetDeleteCallback(_, _))
       .WillRepeatedly(testing::Return(true));
 
   // Add an actor handle.
-  ASSERT_TRUE(actor_manager_->AddNewActorHandle(move(actor_handle), call_site,
-                                                caller_address, false));
+  ASSERT_TRUE(actor_manager_->AddNewActorHandle(
+      move(actor_handle), call_site, caller_address, false));
   // Make sure the subscription request is sent to GCS.
   ASSERT_TRUE(actor_info_accessor_->CheckSubscriptionRequested(actor_id));
   ASSERT_TRUE(actor_manager_->CheckActorHandleExists(actor_id));
 
-  auto actor_handle2 = absl::make_unique<ActorHandle>(
-      actor_id, TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(),
-      function.GetLanguage(), function.GetFunctionDescriptor(), "", 0, "", "", -1, false);
+  auto actor_handle2 = absl::make_unique<ActorHandle>(actor_id,
+                                                      TaskID::Nil(),
+                                                      rpc::Address(),
+                                                      job_id,
+                                                      ObjectID::FromRandom(),
+                                                      function.GetLanguage(),
+                                                      function.GetFunctionDescriptor(),
+                                                      "",
+                                                      0,
+                                                      "",
+                                                      "",
+                                                      -1,
+                                                      false);
   // Make sure the same actor id adding will return false.
-  ASSERT_FALSE(actor_manager_->AddNewActorHandle(move(actor_handle2), call_site,
-                                                 caller_address, false));
+  ASSERT_FALSE(actor_manager_->AddNewActorHandle(
+      move(actor_handle2), call_site, caller_address, false));
 
   // Make sure we can get an actor handle correctly.
const std::shared_ptr actor_handle_to_get = actor_manager_->GetActorHandle(actor_id); @@ -226,9 +268,19 @@ TEST_F(ActorManagerTest, RegisterActorHandles) { const auto call_site = ""; RayFunction function(Language::PYTHON, FunctionDescriptorBuilder::BuildPython("", "", "", "")); - auto actor_handle = absl::make_unique( - actor_id, TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(), - function.GetLanguage(), function.GetFunctionDescriptor(), "", 0, "", "", -1, false); + auto actor_handle = absl::make_unique(actor_id, + TaskID::Nil(), + rpc::Address(), + job_id, + ObjectID::FromRandom(), + function.GetLanguage(), + function.GetFunctionDescriptor(), + "", + 0, + "", + "", + -1, + false); EXPECT_CALL(*reference_counter_, SetDeleteCallback(_, _)) .WillRepeatedly(testing::Return(true)); ObjectID outer_object_id = ObjectID::Nil(); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index c62055bc4..fd6ee001b 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -61,8 +61,9 @@ ActorID CreateActorHelper(std::unordered_map &resources, uint8_t array[] = {1, 2, 3}; auto buffer = std::make_shared(array, sizeof(array)); - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "actor creation task", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("actor creation task", "", "", "")); std::vector> args; args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector()))); @@ -117,7 +118,9 @@ class CoreWorkerTest : public ::testing::Test { // a task can be scheduled to the desired node. for (int i = 0; i < num_nodes; i++) { raylet_socket_names_[i] = - TestSetupUtil::StartRaylet("127.0.0.1", node_manager_port + i, "127.0.0.1", + TestSetupUtil::StartRaylet("127.0.0.1", + node_manager_port + i, + "127.0.0.1", "\"CPU,4.0,resource" + std::to_string(i) + ",10\"", &raylet_store_socket_names_[i]); } @@ -184,7 +187,8 @@ class CoreWorkerTest : public ::testing::Test { void TestActorRestart(std::unordered_map &resources); protected: - bool WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive, + bool WaitForDirectCallActorState(const ActorID &actor_id, + bool wait_alive, int timeout_ms); // Get the pid for the worker process that runs the actor. 
@@ -198,7 +202,8 @@ class CoreWorkerTest : public ::testing::Test { std::string gcs_server_socket_name_; }; -bool CoreWorkerTest::WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive, +bool CoreWorkerTest::WaitForDirectCallActorState(const ActorID &actor_id, + bool wait_alive, int timeout_ms) { auto condition_func = [actor_id, wait_alive]() -> bool { bool actor_alive = @@ -246,7 +251,8 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res ObjectID object_id; RAY_CHECK_OK( - driver.Put(RayObject(buffer2, nullptr, std::vector()), {}, + driver.Put(RayObject(buffer2, nullptr, std::vector()), + {}, &object_id)); std::vector> args; @@ -255,15 +261,19 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res args.emplace_back( new TaskArgByReference(object_id, driver.GetRpcAddress(), /*call_site=*/"")); - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "MergeInputArgsAsOutput", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("MergeInputArgsAsOutput", "", "", "")); TaskOptions options; rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_default_scheduling_strategy(); - auto return_refs = - driver.SubmitTask(func, args, options, /*max_retries=*/0, - /*retry_exceptions=*/false, scheduling_strategy, - /*debugger_breakpoint=*/""); + auto return_refs = driver.SubmitTask(func, + args, + options, + /*max_retries=*/0, + /*retry_exceptions=*/false, + scheduling_strategy, + /*debugger_breakpoint=*/""); auto return_ids = ObjectRefsToIds(return_refs); ASSERT_EQ(return_ids.size(), 1); @@ -275,7 +285,8 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res ASSERT_EQ(results[0]->GetData()->Size(), buffer1->Size() + buffer2->Size()); ASSERT_EQ(memcmp(results[0]->GetData()->Data(), buffer1->Data(), buffer1->Size()), 0); - ASSERT_EQ(memcmp(results[0]->GetData()->Data() + buffer1->Size(), buffer2->Data(), + ASSERT_EQ(memcmp(results[0]->GetData()->Data() + buffer1->Size(), + buffer2->Data(), buffer2->Size()), 0); } @@ -302,8 +313,9 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso buffer2, nullptr, std::vector()))); TaskOptions options{"", 1, resources}; - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "MergeInputArgsAsOutput", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("MergeInputArgsAsOutput", "", "", "")); auto return_ids = ObjectRefsToIds(driver.SubmitActorTask(actor_id, func, args, options).value()); @@ -319,7 +331,8 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso ASSERT_EQ(results[0]->GetData()->Size(), buffer1->Size() + buffer2->Size()); ASSERT_EQ(memcmp(results[0]->GetData()->Data(), buffer1->Data(), buffer1->Size()), 0); - ASSERT_EQ(memcmp(results[0]->GetData()->Data() + buffer1->Size(), buffer2->Data(), + ASSERT_EQ(memcmp(results[0]->GetData()->Data() + buffer1->Size(), + buffer2->Data(), buffer2->Size()), 0); } @@ -335,7 +348,8 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso ObjectID object_id; RAY_CHECK_OK( - driver.Put(RayObject(buffer1, nullptr, std::vector()), {}, + driver.Put(RayObject(buffer1, nullptr, std::vector()), + {}, &object_id)); // Create arguments with PassByRef and PassByValue. 
@@ -346,8 +360,9 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso buffer2, nullptr, std::vector()))); TaskOptions options{"", 1, resources}; - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "MergeInputArgsAsOutput", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("MergeInputArgsAsOutput", "", "", "")); auto return_ids = ObjectRefsToIds(driver.SubmitActorTask(actor_id, func, args, options).value()); @@ -359,7 +374,8 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso ASSERT_EQ(results.size(), 1); ASSERT_EQ(results[0]->GetData()->Size(), buffer1->Size() + buffer2->Size()); ASSERT_EQ(memcmp(results[0]->GetData()->Data(), buffer1->Data(), buffer1->Size()), 0); - ASSERT_EQ(memcmp(results[0]->GetData()->Data() + buffer1->Size(), buffer2->Data(), + ASSERT_EQ(memcmp(results[0]->GetData()->Data() + buffer1->Size(), + buffer2->Data(), buffer2->Size()), 0); } @@ -408,8 +424,9 @@ void CoreWorkerTest::TestActorRestart( buffer1, nullptr, std::vector()))); TaskOptions options{"", 1, resources}; - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "MergeInputArgsAsOutput", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("MergeInputArgsAsOutput", "", "", "")); auto return_ids = ObjectRefsToIds(driver.SubmitActorTask(actor_id, func, args, options).value()); @@ -451,8 +468,9 @@ void CoreWorkerTest::TestActorFailure( buffer1, nullptr, std::vector()))); TaskOptions options{"", 1, resources}; - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "MergeInputArgsAsOutput", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("MergeInputArgsAsOutput", "", "", "")); auto return_ids = ObjectRefsToIds(driver.SubmitActorTask(actor_id, func, args, options).value()); @@ -527,9 +545,17 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { scheduling_strategy}; const auto job_id = NextJobId(); ActorHandle actor_handle(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), - TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(), - function.GetLanguage(), function.GetFunctionDescriptor(), "", - 0, "", "", -1); + TaskID::Nil(), + rpc::Address(), + job_id, + ObjectID::FromRandom(), + function.GetLanguage(), + function.GetFunctionDescriptor(), + "", + 0, + "", + "", + -1); // Manually create `num_tasks` task specs, and for each of them create a // `PushTaskRequest`, this is to batch performance of TaskSpec @@ -543,10 +569,20 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { auto num_returns = options.num_returns; TaskSpecBuilder builder; - builder.SetCommonTaskSpec(RandomTaskId(), options.name, function.GetLanguage(), - function.GetFunctionDescriptor(), job_id, RandomTaskId(), 0, - RandomTaskId(), address, num_returns, resources, resources, - "", 0); + builder.SetCommonTaskSpec(RandomTaskId(), + options.name, + function.GetLanguage(), + function.GetFunctionDescriptor(), + job_id, + RandomTaskId(), + 0, + RandomTaskId(), + address, + num_returns, + resources, + resources, + "", + 0); // Set task arguments. 
for (const auto &arg : args) { builder.AddArg(*arg); @@ -587,8 +623,9 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { buffer, nullptr, std::vector()))); TaskOptions options{"", 1, resources}; - RayFunction func(Language::PYTHON, FunctionDescriptorBuilder::BuildPython( - "MergeInputArgsAsOutput", "", "", "")); + RayFunction func( + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("MergeInputArgsAsOutput", "", "", "")); auto return_ids = ObjectRefsToIds(driver.SubmitActorTask(actor_id, func, args, options).value()); @@ -646,10 +683,18 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { TEST_F(ZeroNodeTest, TestActorHandle) { // Test actor handle serialization and deserialization round trip. JobID job_id = NextJobId(); - ActorHandle original( - ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 0), TaskID::Nil(), - rpc::Address(), job_id, ObjectID::FromRandom(), Language::PYTHON, - FunctionDescriptorBuilder::BuildPython("", "", "", ""), "", 0, "", "", -1); + ActorHandle original(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 0), + TaskID::Nil(), + rpc::Address(), + job_id, + ObjectID::FromRandom(), + Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("", "", "", ""), + "", + 0, + "", + "", + -1); std::string output; original.Serialize(&output); ActorHandle deserialized(output); @@ -711,12 +756,14 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { for (size_t i = 0; i < ids.size(); i++) { const auto &expected = buffers[i]; ASSERT_EQ(results[ids[i]]->GetData()->Size(), expected.GetData()->Size()); - ASSERT_EQ(memcmp(results[ids[i]]->GetData()->Data(), expected.GetData()->Data(), + ASSERT_EQ(memcmp(results[ids[i]]->GetData()->Data(), + expected.GetData()->Data(), expected.GetData()->Size()), 0); ASSERT_EQ(results[ids[i]]->GetMetadata()->Size(), expected.GetMetadata()->Size()); ASSERT_EQ(memcmp(results[ids[i]]->GetMetadata()->Data(), - expected.GetMetadata()->Data(), expected.GetMetadata()->Size()), + expected.GetMetadata()->Data(), + expected.GetMetadata()->Size()), 0); } @@ -870,17 +917,23 @@ TEST_F(SingleNodeTest, TestCancelTasks) { scheduling_strategy.mutable_default_scheduling_strategy(); // Submit func1. The function should start looping forever. - auto return_ids1 = - ObjectRefsToIds(driver.SubmitTask(func1, args, options, /*max_retries=*/0, - /*retry_exceptions=*/false, scheduling_strategy, - /*debugger_breakpoint=*/"")); + auto return_ids1 = ObjectRefsToIds(driver.SubmitTask(func1, + args, + options, + /*max_retries=*/0, + /*retry_exceptions=*/false, + scheduling_strategy, + /*debugger_breakpoint=*/"")); ASSERT_EQ(return_ids1.size(), 1); // Submit func2. The function should be queued at the worker indefinitely. 
- auto return_ids2 = - ObjectRefsToIds(driver.SubmitTask(func2, args, options, /*max_retries=*/0, - /*retry_exceptions=*/false, scheduling_strategy, - /*debugger_breakpoint=*/"")); + auto return_ids2 = ObjectRefsToIds(driver.SubmitTask(func2, + args, + options, + /*max_retries=*/0, + /*retry_exceptions=*/false, + scheduling_strategy, + /*debugger_breakpoint=*/"")); ASSERT_EQ(return_ids2.size(), 1); // Cancel func2 by removing it from the worker's queue diff --git a/src/ray/core_worker/test/direct_actor_transport_mock_test.cc b/src/ray/core_worker/test/direct_actor_transport_mock_test.cc index 88a07a863..1d440a511 100644 --- a/src/ray/core_worker/test/direct_actor_transport_mock_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_mock_test.cc @@ -94,9 +94,10 @@ TEST_F(DirectTaskTransportTest, ActorRegisterFailure) { ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); actor_task_submitter->AddActorQueueIfNotExists(actor_id, -1); ASSERT_TRUE(CheckSubmitTask(task_spec)); - EXPECT_CALL(*task_finisher, FailOrRetryPendingTask( - task_spec.TaskId(), - rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _)); + EXPECT_CALL( + *task_finisher, + FailOrRetryPendingTask( + task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _)); register_cb(Status::IOError("")); } diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 0d2ce23f1..8f4df1d62 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -43,7 +43,8 @@ rpc::ActorDeathCause CreateMockDeathCause() { return death_cause; } -TaskSpecification CreateActorTaskHelper(ActorID actor_id, WorkerID caller_worker_id, +TaskSpecification CreateActorTaskHelper(ActorID actor_id, + WorkerID caller_worker_id, int64_t counter, TaskID caller_id = TaskID::Nil()) { TaskSpecification task; @@ -58,7 +59,8 @@ TaskSpecification CreateActorTaskHelper(ActorID actor_id, WorkerID caller_worker return task; } -rpc::PushTaskRequest CreatePushTaskRequestHelper(ActorID actor_id, int64_t counter, +rpc::PushTaskRequest CreatePushTaskRequestHelper(ActorID actor_id, + int64_t counter, WorkerID caller_worker_id, TaskID caller_id, int64_t caller_timestamp) { @@ -75,7 +77,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { public: const rpc::Address &Addr() const override { return addr; } - void PushActorTask(std::unique_ptr request, bool skip_queue, + void PushActorTask(std::unique_ptr request, + bool skip_queue, const rpc::ClientCallback &callback) override { received_seq_nos.push_back(request->sequence_number()); callbacks.push_back(callback); @@ -112,7 +115,10 @@ class DirectActorSubmitterTest : public ::testing::TestWithParam { task_finisher_(std::make_shared()), io_work(io_context), submitter_( - *client_pool_, *store_, *task_finisher_, actor_creator_, + *client_pool_, + *store_, + *task_finisher_, + actor_creator_, [this](const ActorID &actor_id, int64_t num_queued) { last_queue_warning_ = num_queued; }, @@ -668,13 +674,15 @@ TEST_P(DirectActorSubmitterTest, TestPendingTasks) { ASSERT_FALSE(submitter_.PendingTasksFull(actor_id)); } -INSTANTIATE_TEST_SUITE_P(ExecuteOutOfOrder, DirectActorSubmitterTest, +INSTANTIATE_TEST_SUITE_P(ExecuteOutOfOrder, + DirectActorSubmitterTest, ::testing::Values(true, false)); class MockDependencyWaiter : public DependencyWaiter { public: - MOCK_METHOD2(Wait, void(const std::vector &dependencies, - std::function on_dependencies_available)); + 
MOCK_METHOD2(Wait, + void(const std::vector &dependencies, + std::function on_dependencies_available)); virtual ~MockDependencyWaiter() {} }; @@ -693,8 +701,8 @@ class MockCoreWorkerDirectTaskReceiver : public CoreWorkerDirectTaskReceiver { instrumented_io_context &main_io_service, const TaskHandler &task_handler, const OnTaskDone &task_done) - : CoreWorkerDirectTaskReceiver(worker_context, main_io_service, task_handler, - task_done) {} + : CoreWorkerDirectTaskReceiver( + worker_context, main_io_service, task_handler, task_done) {} void UpdateConcurrencyGroupsCache(const ActorID &actor_id, const std::vector &cgs) { @@ -708,14 +716,18 @@ class DirectActorReceiverTest : public ::testing::Test { : worker_context_(WorkerType::WORKER, JobID::FromInt(0)), worker_client_(std::shared_ptr(new MockWorkerClient())), dependency_waiter_(std::make_shared()) { - auto execute_task = - std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + auto execute_task = std::bind(&DirectActorReceiverTest::MockExecuteTask, + this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3, + std::placeholders::_4); receiver_ = std::make_unique( worker_context_, main_io_service_, execute_task, [] { return Status::OK(); }); receiver_->Init(std::make_shared( [&](const rpc::Address &addr) { return worker_client_; }), - rpc_address_, dependency_waiter_); + rpc_address_, + dependency_waiter_); } Status MockExecuteTask(const TaskSpecification &task_spec, @@ -763,7 +775,8 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) { auto request = CreatePushTaskRequestHelper(actor_id, 0, worker_id, caller_id, curr_timestamp); rpc::PushTaskReply reply; - auto reply_callback = [&callback_count](Status status, std::function success, + auto reply_callback = [&callback_count](Status status, + std::function success, std::function failure) { ++callback_count; ASSERT_TRUE(status.ok()); @@ -778,7 +791,8 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) { auto request = CreatePushTaskRequestHelper(actor_id, 1, worker_id, caller_id, curr_timestamp); rpc::PushTaskReply reply; - auto reply_callback = [&callback_count](Status status, std::function success, + auto reply_callback = [&callback_count](Status status, + std::function success, std::function failure) { ++callback_count; ASSERT_TRUE(status.ok()); @@ -796,7 +810,8 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) { auto request = CreatePushTaskRequestHelper(actor_id, 0, worker_id, caller_id, new_timestamp); rpc::PushTaskReply reply; - auto reply_callback = [&callback_count](Status status, std::function success, + auto reply_callback = [&callback_count](Status status, + std::function success, std::function failure) { ++callback_count; ASSERT_TRUE(status.ok()); @@ -811,7 +826,8 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) { auto request = CreatePushTaskRequestHelper(actor_id, 1, worker_id, caller_id, old_timestamp); rpc::PushTaskReply reply; - auto reply_callback = [&callback_count](Status status, std::function success, + auto reply_callback = [&callback_count](Status status, + std::function success, std::function failure) { ++callback_count; ASSERT_TRUE(!status.ok()); @@ -836,7 +852,8 @@ int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLog::ShutDownRayLog, + argv[0], 
ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(argv[0]); diff --git a/src/ray/core_worker/test/direct_task_transport_mock_test.cc b/src/ray/core_worker/test/direct_task_transport_mock_test.cc index 0201ddc8a..b5628692e 100644 --- a/src/ray/core_worker/test/direct_task_transport_mock_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_mock_test.cc @@ -40,11 +40,13 @@ class DirectTaskTransportTest : public ::testing::Test { client_pool, /* core_worker_client_pool */ nullptr, /* lease_client_factory */ lease_policy, /* lease_policy */ - std::make_shared(), task_finisher, + std::make_shared(), + task_finisher, NodeID::Nil(), /* local_raylet_id */ WorkerType::WORKER, /* worker_type */ 0, /* lease_timeout_ms */ - actor_creator, JobID::Nil() /* job_id */); + actor_creator, + JobID::Nil() /* job_id */); } TaskSpecification GetCreatingTaskSpec(const ActorID &actor_id) { @@ -80,8 +82,8 @@ TEST_F(DirectTaskTransportTest, ActorCreationFail) { auto task_spec = GetCreatingTaskSpec(actor_id); EXPECT_CALL(*task_finisher, CompletePendingTask(_, _, _)).Times(0); EXPECT_CALL(*task_finisher, - FailOrRetryPendingTask(task_spec.TaskId(), - rpc::ErrorType::ACTOR_CREATION_FAILED, _, _, true)); + FailOrRetryPendingTask( + task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _, true)); rpc::ClientCallback create_cb; EXPECT_CALL(*actor_creator, AsyncCreateActor(task_spec, _)) .WillOnce(DoAll(SaveArg<1>(&create_cb), Return(Status::OK()))); diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 5501c3913..91329f3c2 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -36,10 +36,20 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r std::string serialized_runtime_env = "") { TaskSpecBuilder builder; rpc::Address empty_address; - builder.SetCommonTaskSpec(TaskID::Nil(), "dummy_task", Language::PYTHON, - function_descriptor, JobID::Nil(), TaskID::Nil(), 0, - TaskID::Nil(), empty_address, 1, resources, resources, - serialized_runtime_env, depth); + builder.SetCommonTaskSpec(TaskID::Nil(), + "dummy_task", + Language::PYTHON, + function_descriptor, + JobID::Nil(), + TaskID::Nil(), + 0, + TaskID::Nil(), + empty_address, + 1, + resources, + resources, + serialized_runtime_env, + depth); return builder.Build(); } // Calls BuildTaskSpec with empty resources map and empty function descriptor @@ -52,7 +62,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { callbacks.push_back(callback); } - bool ReplyPushTask(Status status = Status::OK(), bool exit = false, + bool ReplyPushTask(Status status = Status::OK(), + bool exit = false, bool is_application_level_error = false) { if (callbacks.size() == 0) { return false; @@ -83,7 +94,8 @@ class MockTaskFinisher : public TaskFinisherInterface { public: MockTaskFinisher() {} - void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &, + void CompletePendingTask(const TaskID &, + const rpc::PushTaskReply &, const rpc::Address &actor_addr) override { num_tasks_complete++; } @@ -93,14 +105,16 @@ class MockTaskFinisher : public TaskFinisherInterface { return false; } - void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type, + void FailPendingTask(const TaskID &task_id, + rpc::ErrorType error_type, const Status *status, const rpc::RayErrorInfo *ray_error_info = nullptr, bool mark_task_object_failed = true) override { 
num_fail_pending_task_calls++; } - bool FailOrRetryPendingTask(const TaskID &task_id, rpc::ErrorType error_type, + bool FailOrRetryPendingTask(const TaskID &task_id, + rpc::ErrorType error_type, const Status *status, const rpc::RayErrorInfo *ray_error_info = nullptr, bool mark_task_object_failed = true) override { @@ -115,7 +129,8 @@ class MockTaskFinisher : public TaskFinisherInterface { } void MarkTaskReturnObjectsFailed( - const TaskSpecification &spec, rpc::ErrorType error_type, + const TaskSpecification &spec, + rpc::ErrorType error_type, const rpc::RayErrorInfo *ray_error_info = nullptr) override {} bool MarkTaskCanceled(const TaskID &task_id) override { return true; } @@ -137,7 +152,9 @@ class MockTaskFinisher : public TaskFinisherInterface { class MockRayletClient : public WorkerLeaseInterface { public: - Status ReturnWorker(int worker_port, const WorkerID &worker_id, bool disconnect_worker, + Status ReturnWorker(int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, bool worker_exiting) override { if (disconnect_worker) { num_workers_disconnected++; @@ -164,9 +181,11 @@ class MockRayletClient : public WorkerLeaseInterface { } void RequestWorkerLease( - const rpc::TaskSpec &task_spec, bool grant_or_reject, + const rpc::TaskSpec &task_spec, + bool grant_or_reject, const ray::rpc::ClientCallback &callback, - const int64_t backlog_size, const bool is_selected_based_on_locality) override { + const int64_t backlog_size, + const bool is_selected_based_on_locality) override { num_workers_requested += 1; if (grant_or_reject) { num_grant_or_reject_leases_requested += 1; @@ -190,8 +209,12 @@ class MockRayletClient : public WorkerLeaseInterface { // Trigger reply to RequestWorkerLease. bool GrantWorkerLease( - const std::string &address, int port, const NodeID &retry_at_raylet_id, - bool cancel = false, std::string worker_id = std::string(), bool reject = false, + const std::string &address, + int port, + const NodeID &retry_at_raylet_id, + bool cancel = false, + std::string worker_id = std::string(), + bool reject = false, const rpc::RequestWorkerLeaseReply::SchedulingFailureType &failure_type = rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED) { rpc::RequestWorkerLeaseReply reply; @@ -527,9 +550,18 @@ TEST(DirectTaskTransportTest, TestLocalityAwareSubmitOneTask) { auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); lease_policy->is_locality_aware = true; - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); @@ -569,9 +601,18 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + 
JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); @@ -611,9 +652,18 @@ TEST(DirectTaskTransportTest, TestRetryTaskApplicationLevelError) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); task.GetMutableMessage().set_retry_exceptions(true); @@ -658,9 +708,18 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -690,10 +749,20 @@ TEST(DirectTaskTransportTest, TestHandleRuntimeEnvSetupFailed) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 2); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 2); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -710,7 +779,12 @@ TEST(DirectTaskTransportTest, TestHandleRuntimeEnvSetupFailed) { // Fail task1 which will fail all the tasks ASSERT_TRUE(raylet_client->GrantWorkerLease( - "", 0, NodeID::Nil(), true, "", false, + "", + 0, + NodeID::Nil(), + true, + "", + false, /*failure_type=*/ rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_RUNTIME_ENV_SETUP_FAILED)); ASSERT_EQ(worker_client->callbacks.size(), 0); @@ -719,7 +793,12 @@ TEST(DirectTaskTransportTest, TestHandleRuntimeEnvSetupFailed) { // Fail task2 ASSERT_TRUE(raylet_client->GrantWorkerLease( - "", 0, NodeID::Nil(), true, "", false, + "", + 0, + NodeID::Nil(), + true, + "", + false, /*failure_type=*/ rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_RUNTIME_ENV_SETUP_FAILED)); ASSERT_EQ(worker_client->callbacks.size(), 0); @@ -741,10 +820,20 @@ TEST(DirectTaskTransportTest, TestWorkerHandleLocalRayletDied) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 2); + 
CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 2); TaskSpecification task1 = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task1).ok()); @@ -761,10 +850,20 @@ TEST(DirectTaskTransportTest, TestDriverHandleLocalRayletDied) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::DRIVER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 2); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::DRIVER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 2); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -806,10 +905,20 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 2); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 2); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -876,10 +985,20 @@ TEST(DirectTaskTransportTest, TestSubmitMultipleTasks) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -939,10 +1058,20 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -1003,10 +1132,20 @@ TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) { auto task_finisher = 
std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -1062,9 +1201,18 @@ TEST(DirectTaskTransportTest, TestConcurrentCancellationAndSubmission) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -1117,10 +1265,20 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); @@ -1164,9 +1322,18 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task1 = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task1).ok()); @@ -1210,10 +1377,18 @@ TEST(DirectTaskTransportTest, TestSpillback) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, lease_client_factory, lease_policy, store, - task_finisher, NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, - JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + lease_client_factory, + 
lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -1275,10 +1450,18 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { auto local_raylet_id = NodeID::FromRandom(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(local_raylet_id); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, lease_client_factory, lease_policy, store, - task_finisher, local_raylet_id, WorkerType::WORKER, kLongTimeout, actor_creator, - JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + lease_client_factory, + lease_policy, + store, + task_finisher, + local_raylet_id, + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -1299,8 +1482,8 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_FALSE(raylet_client->GrantWorkerLease("remote", 1234, NodeID::Nil())); // Trigger a rejection back to the local node. - ASSERT_TRUE(remote_lease_clients[7777]->GrantWorkerLease("local", 1234, local_raylet_id, - false, "", /*reject=*/true)); + ASSERT_TRUE(remote_lease_clients[7777]->GrantWorkerLease( + "local", 1234, local_raylet_id, false, "", /*reject=*/true)); // We should not have created another lease client to the local raylet. ASSERT_EQ(remote_lease_clients.size(), 1); // There should be no more callbacks on the remote node. @@ -1333,7 +1516,8 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { // Helper to run a test that checks that 'same1' and 'same2' are treated as the same // resource shape, while 'different' is treated as a separate shape. void TestSchedulingKey(const std::shared_ptr store, - const TaskSpecification &same1, const TaskSpecification &same2, + const TaskSpecification &same1, + const TaskSpecification &same2, const TaskSpecification &different) { rpc::Address address; auto raylet_client = std::make_shared(); @@ -1343,10 +1527,20 @@ void TestSchedulingKey(const std::shared_ptr store, auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); ASSERT_TRUE(submitter.SubmitTask(same1).ok()); ASSERT_TRUE(submitter.SubmitTask(same2).ok()); @@ -1407,25 +1601,29 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { // Tasks with different resources should request different worker leases. RAY_LOG(INFO) << "Test different resources"; - TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1), + TestSchedulingKey(store, + BuildTaskSpec(resources1, descriptor1), BuildTaskSpec(resources1, descriptor1), BuildTaskSpec(resources2, descriptor1)); // Tasks with different functions should request different worker leases. 
RAY_LOG(INFO) << "Test different functions"; - TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1), + TestSchedulingKey(store, + BuildTaskSpec(resources1, descriptor1), BuildTaskSpec(resources1, descriptor1), BuildTaskSpec(resources1, descriptor2)); // Tasks with different depths should request different worker leases. RAY_LOG(INFO) << "Test different depths"; - TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1, 0), + TestSchedulingKey(store, + BuildTaskSpec(resources1, descriptor1, 0), BuildTaskSpec(resources1, descriptor1, 0), BuildTaskSpec(resources1, descriptor1, 1)); // Tasks with different runtime envs do not request different workers. RAY_LOG(INFO) << "Test different runtimes"; - TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1, 0, "a"), + TestSchedulingKey(store, + BuildTaskSpec(resources1, descriptor1, 0, "a"), BuildTaskSpec(resources1, descriptor1, 0, "b"), BuildTaskSpec(resources1, descriptor1, 1, "a")); @@ -1483,10 +1681,20 @@ TEST(DirectTaskTransportTest, TestBacklogReport) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, - actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); TaskSpecification task1 = BuildEmptyTaskSpec(); @@ -1543,10 +1751,20 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, - /*lease_timeout_ms=*/5, actor_creator, JobID::Nil(), absl::nullopt, 1); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + /*lease_timeout_ms=*/5, + actor_creator, + JobID::Nil(), + absl::nullopt, + 1); TaskSpecification task1 = BuildEmptyTaskSpec(); TaskSpecification task2 = BuildEmptyTaskSpec(); TaskSpecification task3 = BuildEmptyTaskSpec(); @@ -1601,9 +1819,18 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -1653,9 +1880,18 @@ TEST(DirectTaskTransportTest, TestKillPendingTask) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, 
lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -1687,9 +1923,18 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); auto lease_policy = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, - NodeID::Nil(), WorkerType::WORKER, kLongTimeout, actor_creator, JobID::Nil()); + CoreWorkerDirectTaskSubmitter submitter(address, + raylet_client, + client_pool, + nullptr, + lease_policy, + store, + task_finisher, + NodeID::Nil(), + WorkerType::WORKER, + kLongTimeout, + actor_creator, + JobID::Nil()); TaskSpecification task = BuildEmptyTaskSpec(); ObjectID obj1 = ObjectID::FromRandom(); task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); diff --git a/src/ray/core_worker/test/memory_store_test.cc b/src/ray/core_worker/test/memory_store_test.cc index 494d49923..83148a0a6 100644 --- a/src/ray/core_worker/test/memory_store_test.cc +++ b/src/ray/core_worker/test/memory_store_test.cc @@ -185,12 +185,14 @@ TEST(TestMemoryStore, TestObjectAllocator) { return std::make_shared(mock_buffer_manager, data); }; - return std::make_shared(object.GetMetadata(), object.GetNestedRefs(), - std::move(data_factory), /*copy_data=*/true); + return std::make_shared(object.GetMetadata(), + object.GetNestedRefs(), + std::move(data_factory), + /*copy_data=*/true); }; std::shared_ptr memory_store = - std::make_shared(nullptr, nullptr, nullptr, nullptr, - std::move(my_object_allocator)); + std::make_shared( + nullptr, nullptr, nullptr, nullptr, std::move(my_object_allocator)); const int32_t max_rounds = 1000; const std::string hello = "hello"; for (auto i = 0; i < max_rounds; ++i) { diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 8365cd431..be1697d76 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -33,8 +33,10 @@ namespace core { /// for more details on how this class is used. 
class MockWorker { public: - MockWorker(const std::string &store_socket, const std::string &raylet_socket, - int node_manager_port, const gcs::GcsClientOptions &gcs_options, + MockWorker(const std::string &store_socket, + const std::string &raylet_socket, + int node_manager_port, + const gcs::GcsClientOptions &gcs_options, StartupToken startup_token) { CoreWorkerOptions options; options.worker_type = WorkerType::WORKER; @@ -58,7 +60,8 @@ class MockWorker { void RunTaskExecutionLoop() { CoreWorkerProcess::RunTaskExecutionLoop(); } private: - Status ExecuteTask(TaskType task_type, const std::string task_name, + Status ExecuteTask(TaskType task_type, + const std::string task_name, const RayFunction &ray_function, const std::unordered_map &required_resources, const std::vector> &args, @@ -95,8 +98,8 @@ class MockWorker { const_cast(reinterpret_cast(pid_string.data())); auto memory_buffer = std::make_shared(data, pid_string.size(), true); - results->push_back(std::make_shared(memory_buffer, nullptr, - std::vector())); + results->push_back(std::make_shared( + memory_buffer, nullptr, std::vector())); return Status::OK(); } @@ -156,8 +159,8 @@ int main(int argc, char **argv) { auto startup_token = std::stoi(startup_token_str.substr(start)); ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, ""); - ray::core::MockWorker worker(store_socket, raylet_socket, node_manager_port, - gcs_options, startup_token); + ray::core::MockWorker worker( + store_socket, raylet_socket, node_manager_port, gcs_options, startup_token); worker.RunTaskExecutionLoop(); return 0; } diff --git a/src/ray/core_worker/test/object_recovery_manager_test.cc b/src/ray/core_worker/test/object_recovery_manager_test.cc index 951684ee1..22b404da4 100644 --- a/src/ray/core_worker/test/object_recovery_manager_test.cc +++ b/src/ray/core_worker/test/object_recovery_manager_test.cc @@ -59,7 +59,8 @@ class MockTaskResubmitter : public TaskResubmissionInterface { class MockRayletClient : public PinObjectsInterface { public: void PinObjectIDs( - const rpc::Address &caller_address, const std::vector &object_ids, + const rpc::Address &caller_address, + const std::vector &object_ids, const rpc::ClientCallback &callback) override { RAY_LOG(INFO) << "PinObjectIDs " << object_ids.size(); callbacks.push_back(callback); @@ -115,7 +116,9 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { raylet_client_(std::make_shared()), task_resubmitter_(std::make_shared()), ref_counter_(std::make_shared( - rpc::Address(), publisher_.get(), subscriber_.get(), + rpc::Address(), + publisher_.get(), + subscriber_.get(), [](const NodeID &node_id) { return true; }, /*lineage_pinning_enabled=*/lineage_enabled)), manager_( @@ -126,7 +129,9 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { object_directory_->AsyncGetLocations(object_id, callback); return Status::OK(); }, - task_resubmitter_, ref_counter_, memory_store_, + task_resubmitter_, + ref_counter_, + memory_store_, [&](const ObjectID &object_id, rpc::ErrorType reason, bool pin_object) { RAY_CHECK(failed_reconstructions_.count(object_id) == 0); failed_reconstructions_[object_id] = reason; @@ -171,7 +176,12 @@ class ObjectRecoveryManagerTest : public ObjectRecoveryManagerTestBase { TEST_F(ObjectRecoveryLineageDisabledTest, TestNoReconstruction) { // Lineage recording disabled. 
ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); ASSERT_TRUE(manager_.RecoverObject(object_id)); ASSERT_TRUE(failed_reconstructions_.empty()); @@ -194,7 +204,12 @@ TEST_F(ObjectRecoveryLineageDisabledTest, TestNoReconstruction) { TEST_F(ObjectRecoveryLineageDisabledTest, TestPinNewCopy) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); std::vector addresses({rpc::Address()}); object_directory_->SetLocations(object_id, addresses); @@ -208,7 +223,12 @@ TEST_F(ObjectRecoveryLineageDisabledTest, TestPinNewCopy) { TEST_F(ObjectRecoveryManagerTest, TestPinNewCopy) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); std::vector addresses({rpc::Address()}); object_directory_->SetLocations(object_id, addresses); @@ -222,7 +242,12 @@ TEST_F(ObjectRecoveryManagerTest, TestPinNewCopy) { TEST_F(ObjectRecoveryManagerTest, TestReconstruction) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); task_resubmitter_->AddTask(object_id.TaskId(), {}); @@ -235,7 +260,12 @@ TEST_F(ObjectRecoveryManagerTest, TestReconstruction) { TEST_F(ObjectRecoveryManagerTest, TestReconstructionSuppression) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); ref_counter_->AddLocalReference(object_id, ""); @@ -272,7 +302,12 @@ TEST_F(ObjectRecoveryManagerTest, TestReconstructionChain) { std::vector dependencies; for (int i = 0; i < 3; i++) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); task_resubmitter_->AddTask(object_id.TaskId(), dependencies); dependencies = {object_id}; @@ -290,7 +325,12 @@ TEST_F(ObjectRecoveryManagerTest, TestReconstructionChain) { TEST_F(ObjectRecoveryManagerTest, TestReconstructionFails) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); ASSERT_TRUE(manager_.RecoverObject(object_id)); @@ -303,11 +343,21 @@ TEST_F(ObjectRecoveryManagerTest, TestReconstructionFails) { TEST_F(ObjectRecoveryManagerTest, TestDependencyReconstructionFails) { ObjectID dep_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(dep_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(dep_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, 
/*add_local_ref=*/true); task_resubmitter_->AddTask(object_id.TaskId(), {dep_id}); RAY_LOG(INFO) << object_id; @@ -324,7 +374,12 @@ TEST_F(ObjectRecoveryManagerTest, TestDependencyReconstructionFails) { TEST_F(ObjectRecoveryManagerTest, TestLineageEvicted) { ObjectID object_id = ObjectID::FromRandom(); - ref_counter_->AddOwnedObject(object_id, {}, rpc::Address(), "", 0, true, + ref_counter_->AddOwnedObject(object_id, + {}, + rpc::Address(), + "", + 0, + true, /*add_local_ref=*/true); ref_counter_->AddLocalReference(object_id, ""); ref_counter_->EvictLineage(1); diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index 5348d73d3..aa0731e11 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -27,15 +27,22 @@ class MockActorSchedulingQueue { MockActorSchedulingQueue(instrumented_io_context &main_io_service, DependencyWaiter &waiter) : queue_(main_io_service, waiter) {} - void Add(int64_t seq_no, int64_t client_processed_up_to, + void Add(int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback = nullptr, TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) { - queue_.Add(seq_no, client_processed_up_to, std::move(accept_request), - std::move(reject_request), send_reply_callback, "", - FunctionDescriptorBuilder::Empty(), task_id, dependencies); + queue_.Add(seq_no, + client_processed_up_to, + std::move(accept_request), + std::move(reject_request), + send_reply_callback, + "", + FunctionDescriptorBuilder::Empty(), + task_id, + dependencies); } private: diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index cdcf3557d..460cce1b0 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -45,11 +45,14 @@ class TaskManagerTest : public ::testing::Test { publisher_(std::make_shared()), subscriber_(std::make_shared()), reference_counter_(std::shared_ptr(new ReferenceCounter( - rpc::Address(), publisher_.get(), subscriber_.get(), + rpc::Address(), + publisher_.get(), + subscriber_.get(), [this](const NodeID &node_id) { return all_nodes_alive_; }, lineage_pinning_enabled))), manager_( - store_, reference_counter_, + store_, + reference_counter_, [this](const RayObject &object, const ObjectID &object_id) { stored_in_plasma.insert(object_id); }, @@ -57,7 +60,8 @@ class TaskManagerTest : public ::testing::Test { num_retries_++; return Status::OK(); }, - [](const JobID &job_id, const std::string &type, + [](const JobID &job_id, + const std::string &type, const std::string &error_message, double timestamp) { return Status::OK(); }, max_lineage_bytes) {} @@ -114,7 +118,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); ASSERT_EQ(results.size(), 1); ASSERT_FALSE(results[0]->IsException()); - ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), return_object->data().data(), + ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), + return_object->data().data(), return_object->data().size()), 0); ASSERT_EQ(num_retries_, 0); diff --git a/src/ray/core_worker/transport/actor_scheduling_queue.cc b/src/ray/core_worker/transport/actor_scheduling_queue.cc index da4392323..6088707f5 100644 --- a/src/ray/core_worker/transport/actor_scheduling_queue.cc +++ 
b/src/ray/core_worker/transport/actor_scheduling_queue.cc @@ -18,10 +18,13 @@ namespace ray { namespace core { ActorSchedulingQueue::ActorSchedulingQueue( - instrumented_io_context &main_io_service, DependencyWaiter &waiter, + instrumented_io_context &main_io_service, + DependencyWaiter &waiter, std::shared_ptr> pool_manager, - bool is_asyncio, int fiber_max_concurrency, - const std::vector &concurrency_groups, int64_t reorder_wait_seconds) + bool is_asyncio, + int fiber_max_concurrency, + const std::vector &concurrency_groups, + int64_t reorder_wait_seconds) : reorder_wait_seconds_(reorder_wait_seconds), wait_timer_(main_io_service), main_thread_id_(boost::this_thread::get_id()), @@ -65,7 +68,8 @@ size_t ActorSchedulingQueue::Size() const { } /// Add a new actor task's callbacks to the worker queue. -void ActorSchedulingQueue::Add(int64_t seq_no, int64_t client_processed_up_to, +void ActorSchedulingQueue::Add(int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback, @@ -84,10 +88,13 @@ void ActorSchedulingQueue::Add(int64_t seq_no, int64_t client_processed_up_to, } RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_; - pending_actor_tasks_[seq_no] = - InboundRequest(std::move(accept_request), std::move(reject_request), - std::move(send_reply_callback), task_id, dependencies.size() > 0, - concurrency_group_name, function_descriptor); + pending_actor_tasks_[seq_no] = InboundRequest(std::move(accept_request), + std::move(reject_request), + std::move(send_reply_callback), + task_id, + dependencies.size() > 0, + concurrency_group_name, + function_descriptor); if (dependencies.size() > 0) { waiter_.Wait(dependencies, [seq_no, this]() { diff --git a/src/ray/core_worker/transport/actor_scheduling_queue.h b/src/ray/core_worker/transport/actor_scheduling_queue.h index af50dbeb5..1b8593f56 100644 --- a/src/ray/core_worker/transport/actor_scheduling_queue.h +++ b/src/ray/core_worker/transport/actor_scheduling_queue.h @@ -40,10 +40,12 @@ const int kMaxReorderWaitSeconds = 30; class ActorSchedulingQueue : public SchedulingQueue { public: ActorSchedulingQueue( - instrumented_io_context &main_io_service, DependencyWaiter &waiter, + instrumented_io_context &main_io_service, + DependencyWaiter &waiter, std::shared_ptr> pool_manager = std::make_shared>(), - bool is_asyncio = false, int fiber_max_concurrency = 1, + bool is_asyncio = false, + int fiber_max_concurrency = 1, const std::vector &concurrency_groups = {}, int64_t reorder_wait_seconds = kMaxReorderWaitSeconds); @@ -54,7 +56,8 @@ class ActorSchedulingQueue : public SchedulingQueue { size_t Size() const override; /// Add a new actor task's callbacks to the worker queue. 
- void Add(int64_t seq_no, int64_t client_processed_up_to, + void Add(int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback, diff --git a/src/ray/core_worker/transport/actor_scheduling_util.cc b/src/ray/core_worker/transport/actor_scheduling_util.cc index 72cef822b..a5117b9a4 100644 --- a/src/ray/core_worker/transport/actor_scheduling_util.cc +++ b/src/ray/core_worker/transport/actor_scheduling_util.cc @@ -22,8 +22,10 @@ InboundRequest::InboundRequest() {} InboundRequest::InboundRequest( std::function accept_callback, std::function reject_callback, - rpc::SendReplyCallback send_reply_callback, class TaskID task_id, - bool has_dependencies, const std::string &concurrency_group_name, + rpc::SendReplyCallback send_reply_callback, + class TaskID task_id, + bool has_dependencies, + const std::string &concurrency_group_name, const ray::FunctionDescriptor &function_descriptor) : accept_callback_(std::move(accept_callback)), reject_callback_(std::move(reject_callback)), diff --git a/src/ray/core_worker/transport/actor_scheduling_util.h b/src/ray/core_worker/transport/actor_scheduling_util.h index 914dfbcec..3896c9fbb 100644 --- a/src/ray/core_worker/transport/actor_scheduling_util.h +++ b/src/ray/core_worker/transport/actor_scheduling_util.h @@ -29,8 +29,10 @@ class InboundRequest { InboundRequest(); InboundRequest(std::function accept_callback, std::function reject_callback, - rpc::SendReplyCallback send_reply_callback, TaskID task_id, - bool has_dependencies, const std::string &concurrency_group_name, + rpc::SendReplyCallback send_reply_callback, + TaskID task_id, + bool has_dependencies, + const std::string &concurrency_group_name, const ray::FunctionDescriptor &function_descriptor); void Accept(); diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index da52aff65..dda741f10 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -43,7 +43,8 @@ struct TaskState { void InlineDependencies( absl::flat_hash_map> dependencies, - TaskSpecification &task, std::vector *inlined_dependency_ids, + TaskSpecification &task, + std::vector *inlined_dependency_ids, std::vector *contained_ids) { auto &msg = task.GetMutableMessage(); size_t found = 0; @@ -110,31 +111,34 @@ void LocalDependencyResolver::ResolveDependencies( for (const auto &it : state->local_dependencies) { const ObjectID &obj_id = it.first; - in_memory_store_.GetAsync(obj_id, [this, state, obj_id, - on_complete](std::shared_ptr obj) { - RAY_CHECK(obj != nullptr); - bool complete = false; - std::vector inlined_dependency_ids; - std::vector contained_ids; - { - absl::MutexLock lock(&mu_); - state->local_dependencies[obj_id] = std::move(obj); - if (--state->obj_dependencies_remaining == 0) { - InlineDependencies(state->local_dependencies, state->task, - &inlined_dependency_ids, &contained_ids); - if (state->actor_dependencies_remaining == 0) { - complete = true; - num_pending_ -= 1; + in_memory_store_.GetAsync( + obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { + RAY_CHECK(obj != nullptr); + bool complete = false; + std::vector inlined_dependency_ids; + std::vector contained_ids; + { + absl::MutexLock lock(&mu_); + state->local_dependencies[obj_id] = std::move(obj); + if (--state->obj_dependencies_remaining == 0) { + InlineDependencies(state->local_dependencies, + state->task, + 
&inlined_dependency_ids, + &contained_ids); + if (state->actor_dependencies_remaining == 0) { + complete = true; + num_pending_ -= 1; + } + } + } - if (inlined_dependency_ids.size() > 0) { - task_finisher_.OnTaskDependenciesInlined(inlined_dependency_ids, contained_ids); - } - if (complete) { - on_complete(state->status); - } - }); + if (inlined_dependency_ids.size() > 0) { + task_finisher_.OnTaskDependenciesInlined(inlined_dependency_ids, + contained_ids); + } + if (complete) { + on_complete(state->status); + } + }); } for (const auto &actor_id : state->actor_dependencies) { diff --git a/src/ray/core_worker/transport/direct_actor_task_submitter.cc b/src/ray/core_worker/transport/direct_actor_task_submitter.cc index 6a2118b8d..4c638abbb 100644 --- a/src/ray/core_worker/transport/direct_actor_task_submitter.cc +++ b/src/ray/core_worker/transport/direct_actor_task_submitter.cc @@ -37,7 +37,8 @@ void CoreWorkerDirectActorTaskSubmitter::AddActorQueueIfNotExists( } void CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id, - bool force_kill, bool no_restart) { + bool force_kill, + bool no_restart) { absl::MutexLock lock(&mu_); rpc::KillActorRequest request; request.set_intended_actor_id(actor_id.Binary()); @@ -133,8 +134,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe auto status = Status::IOError("cancelling task of dead actor"); // No need to increment the number of completed tasks since the actor is // dead. - RAY_UNUSED(!task_finisher_.FailOrRetryPendingTask(task_id, error_type, &status, - &error_info)); + RAY_UNUSED(!task_finisher_.FailOrRetryPendingTask( + task_id, error_type, &status, &error_info)); } // If the task submission subsequently fails, then the client will receive @@ -224,7 +225,9 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectActor(const ActorID &actor_id, } void CoreWorkerDirectActorTaskSubmitter::DisconnectActor( - const ActorID &actor_id, int64_t num_restarts, bool dead, + const ActorID &actor_id, + int64_t num_restarts, + bool dead, const rpc::ActorDeathCause &death_cause) { RAY_LOG(DEBUG) << "Disconnecting from actor " << actor_id << ", death context type=" << GetActorDeathCauseString(death_cause); @@ -271,8 +274,8 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectActor( task_finisher_.MarkTaskCanceled(task_id); // No need to increment the number of completed tasks since the actor is // dead. - RAY_UNUSED(!task_finisher_.FailOrRetryPendingTask(task_id, error_type, &status, - &error_info)); + RAY_UNUSED(!task_finisher_.FailOrRetryPendingTask( + task_id, error_type, &status, &error_info)); } auto &wait_for_death_info_tasks = queue->second.wait_for_death_info_tasks; @@ -280,8 +283,8 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectActor( RAY_LOG(INFO) << "Failing tasks waiting for death info, size=" << wait_for_death_info_tasks.size() << ", actor_id=" << actor_id; for (auto &net_err_task : wait_for_death_info_tasks) { - RAY_UNUSED(task_finisher_.MarkTaskReturnObjectsFailed(net_err_task.second, - error_type, &error_info)); + RAY_UNUSED(task_finisher_.MarkTaskReturnObjectsFailed( + net_err_task.second, error_type, &error_info)); } // No need to clean up tasks that have been sent and are waiting for @@ -428,7 +431,10 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(ClientQueue &queue, // If the actor is already dead, immediately mark the task object as failed. // Otherwise, it will have a grace period until it marks the object as dead.
will_retry = task_finisher_.FailOrRetryPendingTask( - task_id, error_type, &status, &error_info, + task_id, + error_type, + &status, + &error_info, /*mark_task_object_failed*/ is_actor_dead); if (!is_actor_dead && !will_retry) { // No retry == actor is dead. diff --git a/src/ray/core_worker/transport/direct_actor_task_submitter.h b/src/ray/core_worker/transport/direct_actor_task_submitter.h index e2e07e547..e8c96e5d8 100644 --- a/src/ray/core_worker/transport/direct_actor_task_submitter.h +++ b/src/ray/core_worker/transport/direct_actor_task_submitter.h @@ -48,9 +48,12 @@ class CoreWorkerDirectActorTaskSubmitterInterface { virtual void AddActorQueueIfNotExists(const ActorID &actor_id, int32_t max_pending_calls, bool execute_out_of_order = false) = 0; - virtual void ConnectActor(const ActorID &actor_id, const rpc::Address &address, + virtual void ConnectActor(const ActorID &actor_id, + const rpc::Address &address, int64_t num_restarts) = 0; - virtual void DisconnectActor(const ActorID &actor_id, int64_t num_restarts, bool dead, + virtual void DisconnectActor(const ActorID &actor_id, + int64_t num_restarts, + bool dead, const rpc::ActorDeathCause &death_cause) = 0; virtual void KillActor(const ActorID &actor_id, bool force_kill, bool no_restart) = 0; @@ -64,8 +67,10 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerDirectActorTaskSubmitterInterface { public: CoreWorkerDirectActorTaskSubmitter( - rpc::CoreWorkerClientPool &core_worker_client_pool, CoreWorkerMemoryStore &store, - TaskFinisherInterface &task_finisher, ActorCreatorInterface &actor_creator, + rpc::CoreWorkerClientPool &core_worker_client_pool, + CoreWorkerMemoryStore &store, + TaskFinisherInterface &task_finisher, + ActorCreatorInterface &actor_creator, std::function warn_excess_queueing, instrumented_io_context &io_service) : core_worker_client_pool_(core_worker_client_pool), @@ -84,7 +89,8 @@ class CoreWorkerDirectActorTaskSubmitter /// /// \param[in] actor_id The actor for whom to add a queue. /// \param[in] max_pending_calls The max pending calls for the actor to be added. - void AddActorQueueIfNotExists(const ActorID &actor_id, int32_t max_pending_calls, + void AddActorQueueIfNotExists(const ActorID &actor_id, + int32_t max_pending_calls, bool execute_out_of_order = false); /// Submit a task to an actor for execution. @@ -110,7 +116,8 @@ class CoreWorkerDirectActorTaskSubmitter /// \param[in] num_restarts How many times this actor has been restarted /// before. If we've already seen a later incarnation of the actor, we will /// ignore the command to connect. - void ConnectActor(const ActorID &actor_id, const rpc::Address &address, + void ConnectActor(const ActorID &actor_id, + const rpc::Address &address, int64_t num_restarts); /// Disconnect from a failed actor. @@ -122,7 +129,9 @@ class CoreWorkerDirectActorTaskSubmitter /// \param[in] dead Whether the actor is permanently dead. In this case, all /// pending tasks for the actor should be failed. /// \param[in] death_cause Context about why this actor is dead. - void DisconnectActor(const ActorID &actor_id, int64_t num_restarts, bool dead, + void DisconnectActor(const ActorID &actor_id, + int64_t num_restarts, + bool dead, const rpc::ActorDeathCause &death_cause); /// Set the timestamp for the caller. @@ -221,7 +230,8 @@ class CoreWorkerDirectActorTaskSubmitter /// \param[in] skip_queue Whether to skip the task queue. This will send the /// task for execution immediately. /// \return Void. 
- void PushActorTask(ClientQueue &queue, const TaskSpecification &task_spec, + void PushActorTask(ClientQueue &queue, + const TaskSpecification &task_spec, bool skip_queue) EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Send all pending tasks for an actor. diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 4489b9124..7c133544d 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -26,7 +26,8 @@ namespace ray { namespace core { void CoreWorkerDirectTaskReceiver::Init( - std::shared_ptr client_pool, rpc::Address rpc_address, + std::shared_ptr client_pool, + rpc::Address rpc_address, std::shared_ptr dependency_waiter) { waiter_ = std::move(dependency_waiter); rpc_address_ = rpc_address; @@ -34,7 +35,8 @@ void CoreWorkerDirectTaskReceiver::Init( } void CoreWorkerDirectTaskReceiver::HandleTask( - const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + const rpc::PushTaskRequest &request, + rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use"; // Use `mutable_task_spec()` here as `task_spec()` returns a const reference @@ -56,7 +58,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask( if (task_spec.IsActorCreationTask()) { worker_context_.SetCurrentActorId(task_spec.ActorCreationId()); - SetupActor(task_spec.IsAsyncioActor(), task_spec.MaxActorConcurrency(), + SetupActor(task_spec.IsAsyncioActor(), + task_spec.MaxActorConcurrency(), task_spec.ExecuteOutOfOrder()); } @@ -74,8 +77,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask( } } - auto accept_callback = [this, reply, task_spec, - resource_ids](rpc::SendReplyCallback send_reply_callback) { + auto accept_callback = [this, reply, task_spec, resource_ids]( + rpc::SendReplyCallback send_reply_callback) { if (task_spec.GetMessage().skip_execution()) { send_reply_callback(Status::OK(), nullptr, nullptr); return; @@ -90,9 +93,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask( std::vector> return_objects; bool is_application_level_error = false; - auto status = - task_handler_(task_spec, resource_ids, &return_objects, - reply->mutable_borrowed_refs(), &is_application_level_error); + auto status = task_handler_(task_spec, + resource_ids, + &return_objects, + reply->mutable_borrowed_refs(), + &is_application_level_error); reply->set_is_application_level_error(is_application_level_error); bool objects_valid = return_objects.size() == num_returns; @@ -173,33 +178,49 @@ void CoreWorkerDirectTaskReceiver::HandleTask( RAY_CHECK(cg_it != concurrency_groups_cache_.end()); if (execute_out_of_order_) { it = actor_scheduling_queues_ - .emplace( - task_spec.CallerWorkerId(), - std::unique_ptr(new OutOfOrderActorSchedulingQueue( - task_main_io_service_, *waiter_, pool_manager_, is_asyncio_, - fiber_max_concurrency_, cg_it->second))) + .emplace(task_spec.CallerWorkerId(), + std::unique_ptr( + new OutOfOrderActorSchedulingQueue(task_main_io_service_, + *waiter_, + pool_manager_, + is_asyncio_, + fiber_max_concurrency_, + cg_it->second))) .first; } else { it = actor_scheduling_queues_ .emplace(task_spec.CallerWorkerId(), - std::unique_ptr(new ActorSchedulingQueue( - task_main_io_service_, *waiter_, pool_manager_, is_asyncio_, - fiber_max_concurrency_, cg_it->second))) + std::unique_ptr( + new ActorSchedulingQueue(task_main_io_service_, + *waiter_, + pool_manager_, + is_asyncio_, + fiber_max_concurrency_, + cg_it->second))) 
.first; } } - it->second->Add(request.sequence_number(), request.client_processed_up_to(), - std::move(accept_callback), std::move(reject_callback), - std::move(send_reply_callback), task_spec.ConcurrencyGroupName(), - task_spec.FunctionDescriptor(), task_spec.TaskId(), dependencies); + it->second->Add(request.sequence_number(), + request.client_processed_up_to(), + std::move(accept_callback), + std::move(reject_callback), + std::move(send_reply_callback), + task_spec.ConcurrencyGroupName(), + task_spec.FunctionDescriptor(), + task_spec.TaskId(), + dependencies); } else { // Add the normal task's callbacks to the non-actor scheduling queue. - normal_scheduling_queue_->Add( - request.sequence_number(), request.client_processed_up_to(), - std::move(accept_callback), std::move(reject_callback), - std::move(send_reply_callback), "", task_spec.FunctionDescriptor(), - task_spec.TaskId(), dependencies); + normal_scheduling_queue_->Add(request.sequence_number(), + request.client_processed_up_to(), + std::move(accept_callback), + std::move(reject_callback), + std::move(send_reply_callback), + "", + task_spec.FunctionDescriptor(), + task_spec.TaskId(), + dependencies); } } @@ -220,7 +241,8 @@ bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) { } /// Note that this method is only used for asyncio actors. -void CoreWorkerDirectTaskReceiver::SetupActor(bool is_asyncio, int fiber_max_concurrency, +void CoreWorkerDirectTaskReceiver::SetupActor(bool is_asyncio, + int fiber_max_concurrency, bool execute_out_of_order) { RAY_CHECK(fiber_max_concurrency_ == 0) << "SetupActor should only be called at most once."; diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 0ac6422c3..99212d41d 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -69,7 +69,8 @@ class CoreWorkerDirectTaskReceiver { pool_manager_(std::make_shared>()) {} /// Initialize this receiver. This must be called prior to use. - void Init(std::shared_ptr, rpc::Address rpc_address, + void Init(std::shared_ptr, + rpc::Address rpc_address, std::shared_ptr dependency_waiter); /// Handle a `PushTask` request. If it's an actor request, this function will enqueue @@ -79,7 +80,8 @@ class CoreWorkerDirectTaskReceiver { /// \param[in] request The request message. /// \param[out] reply The reply message. /// \param[in] send_reply_callback The callback to be called when the request is done. - void HandleTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + void HandleTask(const rpc::PushTaskRequest &request, + rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback); /// Pop tasks from the queue and execute them sequentially diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 40cb5aed6..9721d158b 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -48,8 +48,8 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { // Copy the actor's reply to the GCS for ref counting purposes. 
rpc::PushTaskReply push_task_reply; push_task_reply.mutable_borrowed_refs()->CopyFrom(reply.borrowed_refs()); - task_finisher_->CompletePendingTask(task_id, push_task_reply, - reply.actor_address()); + task_finisher_->CompletePendingTask( + task_id, push_task_reply, reply.actor_address()); } else { RAY_LOG(ERROR) << "Failed to create actor " << actor_id << " with status: " << status.ToString(); @@ -70,11 +70,12 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { if (keep_executing) { // Note that the dependencies in the task spec are mutated to only contain // plasma dependencies after ResolveDependencies finishes. - const SchedulingKey scheduling_key( - task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), - task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() - : ActorID::Nil(), - task_spec.GetRuntimeEnvHash()); + const SchedulingKey scheduling_key(task_spec.GetSchedulingClass(), + task_spec.GetDependencyIds(), + task_spec.IsActorCreationTask() + ? task_spec.ActorCreationId() + : ActorID::Nil(), + task_spec.GetRuntimeEnvHash()); auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key]; scheduling_key_entry.task_queue.push_back(task_spec); scheduling_key_entry.resource_spec = task_spec; @@ -88,8 +89,11 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { worker_to_lease_entry_.end()); auto &lease_entry = worker_to_lease_entry_[active_worker_addr]; if (!lease_entry.is_busy) { - OnWorkerIdle(active_worker_addr, scheduling_key, /*was_error*/ false, - /*worker_exiting*/ false, lease_entry.assigned_resources); + OnWorkerIdle(active_worker_addr, + scheduling_key, + /*was_error*/ false, + /*worker_exiting*/ false, + lease_entry.assigned_resources); break; } } @@ -106,7 +110,8 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { } void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient( - const rpc::WorkerAddress &addr, std::shared_ptr lease_client, + const rpc::WorkerAddress &addr, + std::shared_ptr lease_client, const google::protobuf::RepeatedPtrField &assigned_resources, const SchedulingKey &scheduling_key) { client_cache_->GetOrConnect(addr.ToProto()); @@ -121,7 +126,8 @@ void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient( } void CoreWorkerDirectTaskSubmitter::ReturnWorker(const rpc::WorkerAddress addr, - bool was_error, bool worker_exiting, + bool was_error, + bool worker_exiting, const SchedulingKey &scheduling_key) { RAY_LOG(DEBUG) << "Returning worker " << addr.worker_id << " to raylet " << addr.raylet_id; @@ -140,8 +146,8 @@ void CoreWorkerDirectTaskSubmitter::ReturnWorker(const rpc::WorkerAddress addr, scheduling_key_entries_.erase(scheduling_key); } - auto status = lease_entry.lease_client->ReturnWorker(addr.port, addr.worker_id, - was_error, worker_exiting); + auto status = lease_entry.lease_client->ReturnWorker( + addr.port, addr.worker_id, was_error, worker_exiting); if (!status.ok()) { RAY_LOG(ERROR) << "Error returning worker to raylet: " << status.ToString(); } @@ -149,7 +155,9 @@ void CoreWorkerDirectTaskSubmitter::ReturnWorker(const rpc::WorkerAddress addr, } void CoreWorkerDirectTaskSubmitter::OnWorkerIdle( - const rpc::WorkerAddress &addr, const SchedulingKey &scheduling_key, bool was_error, + const rpc::WorkerAddress &addr, + const SchedulingKey &scheduling_key, + bool was_error, bool worker_exiting, const google::protobuf::RepeatedPtrField &assigned_resources) { auto &lease_entry = worker_to_lease_entry_[addr]; @@ -211,8 +219,9 @@ void 
CoreWorkerDirectTaskSubmitter::CancelWorkerLeaseIfNeeded( auto &task_id = pending_lease_request.first; RAY_LOG(DEBUG) << "Canceling lease request " << task_id; lease_client->CancelWorkerLease( - task_id, [this, scheduling_key](const Status &status, - const rpc::CancelWorkerLeaseReply &reply) { + task_id, + [this, scheduling_key](const Status &status, + const rpc::CancelWorkerLeaseReply &reply) { absl::MutexLock lock(&mu_); if (status.ok() && !reply.success()) { // The cancellation request can fail if the raylet does not have @@ -241,8 +250,9 @@ CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient( if (it == remote_lease_clients_.end()) { RAY_LOG(INFO) << "Connecting to raylet " << raylet_id; it = remote_lease_clients_ - .emplace(raylet_id, lease_client_factory_(raylet_address->ip_address(), - raylet_address->port())) + .emplace(raylet_id, + lease_client_factory_(raylet_address->ip_address(), + raylet_address->port())) .first; } lease_client = it->second; @@ -389,8 +399,10 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( error_info.mutable_runtime_env_setup_failed_error()->set_error_message( reply.scheduling_failure_message()); RAY_UNUSED(task_finisher_->FailPendingTask( - task_spec.TaskId(), rpc::ErrorType::RUNTIME_ENV_SETUP_FAILED, - /*status*/ nullptr, &error_info)); + task_spec.TaskId(), + rpc::ErrorType::RUNTIME_ENV_SETUP_FAILED, + /*status*/ nullptr, + &error_info)); } else { if (task_spec.IsActorCreationTask()) { RAY_UNUSED(task_finisher_->FailPendingTask( @@ -427,11 +439,14 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( auto resources_copy = reply.resource_mapping(); - AddWorkerLeaseClient(addr, std::move(lease_client), resources_copy, - scheduling_key); + AddWorkerLeaseClient( + addr, std::move(lease_client), resources_copy, scheduling_key); RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1); - OnWorkerIdle(addr, scheduling_key, - /*error=*/false, /*worker_exiting=*/false, resources_copy); + OnWorkerIdle(addr, + scheduling_key, + /*error=*/false, + /*worker_exiting=*/false, + resources_copy); } else { // The raylet redirected us to a different raylet to retry at. 
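// Spillback is not chained: only the request sent to the local raylet may be
// redirected, so a spilled-to raylet either grants the lease or rejects it
// back to the local node (see TestSpillbackRoundTrip above). That invariant
// is what the RAY_CHECK below asserts.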
RAY_CHECK(!is_spillback); @@ -484,14 +499,17 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( } } }, - task_queue.size(), is_selected_based_on_locality); + task_queue.size(), + is_selected_based_on_locality); scheduling_key_entry.pending_lease_requests.emplace(task_id, *raylet_address); ReportWorkerBacklogIfNeeded(scheduling_key); } void CoreWorkerDirectTaskSubmitter::PushNormalTask( - const rpc::WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, - const SchedulingKey &scheduling_key, const TaskSpecification &task_spec, + const rpc::WorkerAddress &addr, + rpc::CoreWorkerClientInterface &client, + const SchedulingKey &scheduling_key, + const TaskSpecification &task_spec, const google::protobuf::RepeatedPtrField &assigned_resources) { RAY_LOG(DEBUG) << "Pushing task " << task_spec.TaskId() << " to worker " << addr.worker_id << " of raylet " << addr.raylet_id; @@ -508,7 +526,13 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( request->set_intended_worker_id(addr.worker_id.Binary()); client.PushNormalTask( std::move(request), - [this, task_spec, task_id, is_actor, is_actor_creation, scheduling_key, addr, + [this, + task_spec, + task_id, + is_actor, + is_actor_creation, + scheduling_key, + addr, assigned_resources](Status status, const rpc::PushTaskReply &reply) { { RAY_LOG(DEBUG) << "Task " << task_id << " finished from worker " @@ -530,9 +554,11 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( if (!status.ok() || !is_actor_creation || reply.worker_exiting()) { // Successful actor creation leases the worker indefinitely from the raylet. - OnWorkerIdle(addr, scheduling_key, + OnWorkerIdle(addr, + scheduling_key, /*error=*/!status.ok(), - /*worker_exiting=*/reply.worker_exiting(), assigned_resources); + /*worker_exiting=*/reply.worker_exiting(), + assigned_resources); } } if (!status.ok()) { @@ -555,11 +581,13 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( } Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, - bool force_kill, bool recursive) { + bool force_kill, + bool recursive) { RAY_LOG(INFO) << "Cancelling a task: " << task_spec.TaskId() << " force_kill: " << force_kill << " recursive: " << recursive; const SchedulingKey scheduling_key( - task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), + task_spec.GetSchedulingClass(), + task_spec.GetDependencyIds(), task_spec.IsActorCreationTask() ? 
task_spec.ActorCreationId() : ActorID::Nil(), task_spec.GetRuntimeEnvHash()); std::shared_ptr client = nullptr; @@ -621,8 +649,9 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, request.set_force_kill(force_kill); request.set_recursive(recursive); client->CancelTask( - request, [this, task_spec, scheduling_key, force_kill, recursive]( - const Status &status, const rpc::CancelTaskReply &reply) { + request, + [this, task_spec, scheduling_key, force_kill, recursive]( + const Status &status, const rpc::CancelTaskReply &reply) { absl::MutexLock lock(&mu_); cancelled_tasks_.erase(task_spec.TaskId()); @@ -634,8 +663,11 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, RayConfig::instance().cancellation_retry_ms())); } cancel_retry_timer_->async_wait( - boost::bind(&CoreWorkerDirectTaskSubmitter::CancelTask, this, task_spec, - force_kill, recursive)); + boost::bind(&CoreWorkerDirectTaskSubmitter::CancelTask, + this, + task_spec, + force_kill, + recursive)); } } // Retry is not attempted if !status.ok() because force-kill may kill the worker @@ -646,7 +678,8 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr, - bool force_kill, bool recursive) { + bool force_kill, + bool recursive) { auto maybe_client = client_cache_->GetByID(rpc::WorkerAddress(worker_addr).worker_id); if (!maybe_client.has_value()) { diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 88bf24977..0447ff8d7 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -57,14 +57,18 @@ using SchedulingKey = class CoreWorkerDirectTaskSubmitter { public: explicit CoreWorkerDirectTaskSubmitter( - rpc::Address rpc_address, std::shared_ptr lease_client, + rpc::Address rpc_address, + std::shared_ptr lease_client, std::shared_ptr core_worker_client_pool, LeaseClientFactoryFn lease_client_factory, std::shared_ptr lease_policy, std::shared_ptr store, - std::shared_ptr task_finisher, NodeID local_raylet_id, - WorkerType worker_type, int64_t lease_timeout_ms, - std::shared_ptr actor_creator, const JobID &job_id, + std::shared_ptr task_finisher, + NodeID local_raylet_id, + WorkerType worker_type, + int64_t lease_timeout_ms, + std::shared_ptr actor_creator, + const JobID &job_id, absl::optional cancel_timer = absl::nullopt, uint64_t max_pending_lease_requests_per_scheduling_category = ::RayConfig::instance().max_pending_lease_requests_per_scheduling_category()) @@ -95,8 +99,10 @@ class CoreWorkerDirectTaskSubmitter { /// \param[in] force_kill Whether to kill the worker executing the task. Status CancelTask(TaskSpecification task_spec, bool force_kill, bool recursive); - Status CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr, - bool force_kill, bool recursive); + Status CancelRemoteTask(const ObjectID &object_id, + const rpc::Address &worker_addr, + bool force_kill, + bool recursive); /// Check that the scheduling_key_entries_ hashmap is empty by calling the private /// CheckNoSchedulingKeyEntries function after acquiring the lock. bool CheckNoSchedulingKeyEntriesPublic() { @@ -127,7 +133,9 @@ class CoreWorkerDirectTaskSubmitter { /// \param[in] worker_exiting Whether the worker is exiting. /// \param[in] assigned_resources Resource ids previously assigned to the worker. 
void OnWorkerIdle( - const rpc::WorkerAddress &addr, const SchedulingKey &task_queue_key, bool was_error, + const rpc::WorkerAddress &addr, + const SchedulingKey &task_queue_key, + bool was_error, bool worker_exiting, const google::protobuf::RepeatedPtrField &assigned_resources) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -163,7 +171,8 @@ class CoreWorkerDirectTaskSubmitter { /// Set up client state for newly granted worker lease. void AddWorkerLeaseClient( - const rpc::WorkerAddress &addr, std::shared_ptr lease_client, + const rpc::WorkerAddress &addr, + std::shared_ptr lease_client, const google::protobuf::RepeatedPtrField &assigned_resources, const SchedulingKey &scheduling_key) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -171,7 +180,9 @@ /// \param[in] addr The address of the worker. /// \param[in] was_error Whether the task failed to be submitted. /// \param[in] worker_exiting Whether the worker is exiting. - void ReturnWorker(const rpc::WorkerAddress addr, bool was_error, bool worker_exiting, + void ReturnWorker(const rpc::WorkerAddress addr, + bool was_error, + bool worker_exiting, const SchedulingKey &scheduling_key) EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Check that the scheduling_key_entries_ hashmap is empty. @@ -254,8 +265,8 @@ int64_t lease_expiration_time = 0, google::protobuf::RepeatedPtrField assigned_resources = google::protobuf::RepeatedPtrField(), - SchedulingKey scheduling_key = std::make_tuple(0, std::vector(), - ActorID::Nil(), 0)) + SchedulingKey scheduling_key = + std::make_tuple(0, std::vector(), ActorID::Nil(), 0)) : lease_client(lease_client), lease_expiration_time(lease_expiration_time), assigned_resources(assigned_resources), diff --git a/src/ray/core_worker/transport/normal_scheduling_queue.cc b/src/ray/core_worker/transport/normal_scheduling_queue.cc index d189ea96a..e46af5f29 100644 --- a/src/ray/core_worker/transport/normal_scheduling_queue.cc +++ b/src/ray/core_worker/transport/normal_scheduling_queue.cc @@ -36,21 +36,27 @@ size_t NormalSchedulingQueue::Size() const { } /// Add a new task's callbacks to the worker queue. void NormalSchedulingQueue::Add( - int64_t seq_no, int64_t client_processed_up_to, + int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, - rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, - const FunctionDescriptor &function_descriptor, TaskID task_id, + rpc::SendReplyCallback send_reply_callback, + const std::string &concurrency_group_name, + const FunctionDescriptor &function_descriptor, + TaskID task_id, const std::vector &dependencies) { absl::MutexLock lock(&mu_); // Normal tasks should not have ordering constraints. RAY_CHECK(seq_no == -1); // Create an InboundRequest object for the new task, and add it to the queue. - pending_normal_tasks_.push_back( - InboundRequest(std::move(accept_request), std::move(reject_request), - std::move(send_reply_callback), task_id, dependencies.size() > 0, - /*concurrency_group_name=*/"", function_descriptor)); + pending_normal_tasks_.push_back(InboundRequest(std::move(accept_request), + std::move(reject_request), + std::move(send_reply_callback), + task_id, + dependencies.size() > 0, + /*concurrency_group_name=*/"", + function_descriptor)); } // Search for an InboundRequest associated with the task that we are trying to cancel. 
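For context on the NormalSchedulingQueue::Add hunk above: normal (non-actor) tasks carry no ordering constraint, which is why the body asserts seq_no == -1. The sketch below is an illustrative call site in the patch's one-argument-per-line style; the queue, task_id, and lambda bodies are placeholders, not code from this patch:

  queue.Add(/*seq_no=*/-1,  // normal tasks must pass -1
            /*client_processed_up_to=*/-1,
            /*accept_request=*/[](rpc::SendReplyCallback cb) { /* run the task */ },
            /*reject_request=*/[](rpc::SendReplyCallback cb) { /* fail the task */ },
            /*send_reply_callback=*/nullptr,
            /*concurrency_group_name=*/"",
            FunctionDescriptorBuilder::Empty(),
            task_id,
            /*dependencies=*/{});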
@@ -59,7 +65,8 @@ void NormalSchedulingQueue::Add( bool NormalSchedulingQueue::CancelTaskIfFound(TaskID task_id) { absl::MutexLock lock(&mu_); for (std::deque::reverse_iterator it = pending_normal_tasks_.rbegin(); - it != pending_normal_tasks_.rend(); ++it) { + it != pending_normal_tasks_.rend(); + ++it) { if (it->TaskID() == task_id) { pending_normal_tasks_.erase(std::next(it).base()); return true; diff --git a/src/ray/core_worker/transport/normal_scheduling_queue.h b/src/ray/core_worker/transport/normal_scheduling_queue.h index 91bf8a02b..6027ce9d2 100644 --- a/src/ray/core_worker/transport/normal_scheduling_queue.h +++ b/src/ray/core_worker/transport/normal_scheduling_queue.h @@ -39,7 +39,8 @@ class NormalSchedulingQueue : public SchedulingQueue { /// Add a new task's callbacks to the worker queue. void Add( - int64_t seq_no, int64_t client_processed_up_to, + int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback, diff --git a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc index b6f9fc0e9..413f631f4 100644 --- a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc +++ b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc @@ -18,9 +18,11 @@ namespace ray { namespace core { OutOfOrderActorSchedulingQueue::OutOfOrderActorSchedulingQueue( - instrumented_io_context &main_io_service, DependencyWaiter &waiter, + instrumented_io_context &main_io_service, + DependencyWaiter &waiter, std::shared_ptr> pool_manager, - bool is_asyncio, int fiber_max_concurrency, + bool is_asyncio, + int fiber_max_concurrency, const std::vector &concurrency_groups) : main_thread_id_(boost::this_thread::get_id()), waiter_(waiter), @@ -59,17 +61,23 @@ size_t OutOfOrderActorSchedulingQueue::Size() const { } void OutOfOrderActorSchedulingQueue::Add( - int64_t seq_no, int64_t client_processed_up_to, + int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, - rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, - const ray::FunctionDescriptor &function_descriptor, TaskID task_id, + rpc::SendReplyCallback send_reply_callback, + const std::string &concurrency_group_name, + const ray::FunctionDescriptor &function_descriptor, + TaskID task_id, const std::vector &dependencies) { RAY_CHECK(boost::this_thread::get_id() == main_thread_id_); - auto request = - InboundRequest(std::move(accept_request), std::move(reject_request), - std::move(send_reply_callback), task_id, dependencies.size() > 0, - concurrency_group_name, function_descriptor); + auto request = InboundRequest(std::move(accept_request), + std::move(reject_request), + std::move(send_reply_callback), + task_id, + dependencies.size() > 0, + concurrency_group_name, + function_descriptor); if (dependencies.size() > 0) { waiter_.Wait(dependencies, [this, request = std::move(request)]() mutable { diff --git a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h index 78cce7ddd..0a18366a4 100644 --- a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h +++ b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h @@ -37,10 +37,12 @@ namespace core { class OutOfOrderActorSchedulingQueue : public SchedulingQueue { public: OutOfOrderActorSchedulingQueue( - 
instrumented_io_context &main_io_service, DependencyWaiter &waiter, + instrumented_io_context &main_io_service, + DependencyWaiter &waiter, std::shared_ptr> pool_manager = std::make_shared>(), - bool is_asyncio = false, int fiber_max_concurrency = 1, + bool is_asyncio = false, + int fiber_max_concurrency = 1, const std::vector &concurrency_groups = {}); void Stop() override; @@ -50,7 +52,8 @@ class OutOfOrderActorSchedulingQueue : public SchedulingQueue { size_t Size() const override; /// Add a new actor task's callbacks to the worker queue. - void Add(int64_t seq_no, int64_t client_processed_up_to, + void Add(int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback, diff --git a/src/ray/core_worker/transport/scheduling_queue.h b/src/ray/core_worker/transport/scheduling_queue.h index 6913bdc00..5ea71e926 100644 --- a/src/ray/core_worker/transport/scheduling_queue.h +++ b/src/ray/core_worker/transport/scheduling_queue.h @@ -27,7 +27,8 @@ namespace core { class SchedulingQueue { public: virtual ~SchedulingQueue() = default; - virtual void Add(int64_t seq_no, int64_t client_processed_up_to, + virtual void Add(int64_t seq_no, + int64_t client_processed_up_to, std::function accept_request, std::function reject_request, rpc::SendReplyCallback send_reply_callback, diff --git a/src/ray/gcs/asio.cc b/src/ray/gcs/asio.cc index eb1724a5c..386a0811b 100644 --- a/src/ray/gcs/asio.cc +++ b/src/ray/gcs/asio.cc @@ -64,16 +64,18 @@ RedisAsioClient::RedisAsioClient(instrumented_io_context &io_service, void RedisAsioClient::operate() { if (read_requested_ && !read_in_progress_) { read_in_progress_ = true; - socket_.async_read_some(boost::asio::null_buffers(), - boost::bind(&RedisAsioClient::handle_io, this, - boost::asio::placeholders::error, false)); + socket_.async_read_some( + boost::asio::null_buffers(), + boost::bind( + &RedisAsioClient::handle_io, this, boost::asio::placeholders::error, false)); } if (write_requested_ && !write_in_progress_) { write_in_progress_ = true; - socket_.async_write_some(boost::asio::null_buffers(), - boost::bind(&RedisAsioClient::handle_io, this, - boost::asio::placeholders::error, true)); + socket_.async_write_some( + boost::asio::null_buffers(), + boost::bind( + &RedisAsioClient::handle_io, this, boost::asio::placeholders::error, true)); } } diff --git a/src/ray/gcs/gcs_client/accessor.cc b/src/ray/gcs/gcs_client/accessor.cc index e4b58461a..031ab846f 100644 --- a/src/ray/gcs/gcs_client/accessor.cc +++ b/src/ray/gcs/gcs_client/accessor.cc @@ -178,8 +178,10 @@ Status ActorInfoAccessor::AsyncGetAll( } Status ActorInfoAccessor::AsyncGetByName( - const std::string &name, const std::string &ray_namespace, - const OptionalItemCallback &callback, int64_t timeout_ms) { + const std::string &name, + const std::string &ray_namespace, + const OptionalItemCallback &callback, + int64_t timeout_ms) { RAY_LOG(DEBUG) << "Getting actor info, name = " << name; rpc::GetNamedActorInfoRequest request; request.set_name(name); @@ -215,7 +217,8 @@ Status ActorInfoAccessor::SyncGetByName(const std::string &name, } Status ActorInfoAccessor::AsyncListNamedActors( - bool all_namespaces, const std::string &ray_namespace, + bool all_namespaces, + const std::string &ray_namespace, const OptionalItemCallback> &callback, int64_t timeout_ms) { RAY_LOG(DEBUG) << "Listing actors"; @@ -237,14 +240,15 @@ Status ActorInfoAccessor::AsyncListNamedActors( } Status ActorInfoAccessor::SyncListNamedActors( - bool 
all_namespaces, const std::string &ray_namespace,
+    bool all_namespaces,
+    const std::string &ray_namespace,
     std::vector<std::pair<std::string, std::string>> &actors) {
   rpc::ListNamedActorsRequest request;
   request.set_all_namespaces(all_namespaces);
   request.set_ray_namespace(ray_namespace);
   rpc::ListNamedActorsReply reply;
-  auto status = client_impl_->GetGcsRpcClient().SyncListNamedActors(request, &reply,
-                                                                    GetGcsTimeoutMs());
+  auto status = client_impl_->GetGcsRpcClient().SyncListNamedActors(
+      request, &reply, GetGcsTimeoutMs());
   if (!status.ok()) {
     return status;
   }
@@ -279,12 +283,13 @@ Status ActorInfoAccessor::SyncRegisterActor(const ray::TaskSpecification &task_s
   rpc::RegisterActorRequest request;
   rpc::RegisterActorReply reply;
   request.mutable_task_spec()->CopyFrom(task_spec.GetMessage());
-  auto status = client_impl_->GetGcsRpcClient().SyncRegisterActor(request, &reply,
-                                                                  GetGcsTimeoutMs());
+  auto status = client_impl_->GetGcsRpcClient().SyncRegisterActor(
+      request, &reply, GetGcsTimeoutMs());
   return status;
 }
 
-Status ActorInfoAccessor::AsyncKillActor(const ActorID &actor_id, bool force_kill,
+Status ActorInfoAccessor::AsyncKillActor(const ActorID &actor_id,
+                                         bool force_kill,
                                          bool no_restart,
                                          const ray::gcs::StatusCallback &callback) {
   rpc::KillActorViaGcsRequest request;
@@ -330,20 +335,20 @@ Status ActorInfoAccessor::AsyncSubscribe(
       << ", job id = " << actor_id.JobId();
   RAY_CHECK(subscribe != nullptr)
       << "Failed to subscribe actor, actor id = " << actor_id;
-  auto fetch_data_operation = [this, actor_id,
-                               subscribe](const StatusCallback &fetch_done) {
-    auto callback = [actor_id, subscribe, fetch_done](
-                        const Status &status,
-                        const boost::optional<rpc::ActorTableData> &result) {
-      if (result) {
-        subscribe(actor_id, *result);
-      }
-      if (fetch_done) {
-        fetch_done(status);
-      }
-    };
-    RAY_CHECK_OK(AsyncGet(actor_id, callback));
-  };
+  auto fetch_data_operation =
+      [this, actor_id, subscribe](const StatusCallback &fetch_done) {
+        auto callback = [actor_id, subscribe, fetch_done](
+                            const Status &status,
+                            const boost::optional<rpc::ActorTableData> &result) {
+          if (result) {
+            subscribe(actor_id, *result);
+          }
+          if (fetch_done) {
+            fetch_done(status);
+          }
+        };
+        RAY_CHECK_OK(AsyncGet(actor_id, callback));
+      };
 
   {
     absl::MutexLock lock(&mutex_);
@@ -362,15 +367,16 @@ Status ActorInfoAccessor::AsyncSubscribe(
               << ", retrying ...";
         absl::SleepFor(absl::Seconds(1));
       }
-      return client_impl_->GetGcsSubscriber().SubscribeActor(actor_id, subscribe,
-                                                             subscribe_done);
+      return client_impl_->GetGcsSubscriber().SubscribeActor(
+          actor_id, subscribe, subscribe_done);
     };
 
     fetch_data_operations_[actor_id] = fetch_data_operation;
   }
 
   return client_impl_->GetGcsSubscriber().SubscribeActor(
-      actor_id, subscribe,
-      [fetch_data_operation, done](const Status &) { fetch_data_operation(done); });
+      actor_id, subscribe, [fetch_data_operation, done](const Status &) {
+        fetch_data_operation(done);
+      });
 }
 
 Status ActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) {
@@ -427,8 +433,9 @@ Status NodeInfoAccessor::RegisterSelf(const GcsNodeInfo &local_node_info,
   rpc::RegisterNodeRequest request;
   request.mutable_node_info()->CopyFrom(local_node_info);
   client_impl_->GetGcsRpcClient().RegisterNode(
-      request, [this, node_id, local_node_info, callback](
-                   const Status &status, const rpc::RegisterNodeReply &reply) {
+      request,
+      [this, node_id, local_node_info, callback](const Status &status,
+                                                 const rpc::RegisterNodeReply &reply) {
         if (status.ok()) {
           local_node_info_.CopyFrom(local_node_info);
           local_node_id_ = NodeID::FromBinary(local_node_info.node_id());
@@ -835,8 +842,9 @@ Status StatsInfoAccessor::AsyncAddProfileData(
   rpc::AddProfileDataRequest request;
   request.mutable_profile_data()->CopyFrom(*data_ptr);
   client_impl_->GetGcsRpcClient().AddProfileData(
-      request, [data_ptr, node_id, callback](const Status &status,
-                                             const rpc::AddProfileDataReply &reply) {
+      request,
+      [data_ptr, node_id, callback](const Status &status,
+                                    const rpc::AddProfileDataReply &reply) {
         if (callback) {
           callback(status);
         }
@@ -910,8 +918,9 @@ Status WorkerInfoAccessor::AsyncReportWorkerFailure(
   rpc::ReportWorkerFailureRequest request;
   request.mutable_worker_failure()->CopyFrom(*data_ptr);
   client_impl_->GetGcsRpcClient().ReportWorkerFailure(
-      request, [worker_address, callback](const Status &status,
-                                          const rpc::ReportWorkerFailureReply &reply) {
+      request,
+      [worker_address, callback](const Status &status,
+                                 const rpc::ReportWorkerFailureReply &reply) {
         if (callback) {
           callback(status);
         }
@@ -1003,8 +1012,9 @@ Status PlacementGroupInfoAccessor::AsyncGet(
   rpc::GetPlacementGroupRequest request;
   request.set_placement_group_id(placement_group_id.Binary());
   client_impl_->GetGcsRpcClient().GetPlacementGroup(
-      request, [placement_group_id, callback](const Status &status,
-                                              const rpc::GetPlacementGroupReply &reply) {
+      request,
+      [placement_group_id, callback](const Status &status,
+                                     const rpc::GetPlacementGroupReply &reply) {
         if (reply.has_placement_group_table_data()) {
           callback(status, reply.placement_group_table_data());
         } else {
@@ -1017,7 +1027,8 @@ Status PlacementGroupInfoAccessor::AsyncGet(
 }
 
 Status PlacementGroupInfoAccessor::AsyncGetByName(
-    const std::string &name, const std::string &ray_namespace,
+    const std::string &name,
+    const std::string &ray_namespace,
     const OptionalItemCallback<rpc::PlacementGroupTableData> &callback,
     int64_t timeout_ms) {
   RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name;
@@ -1070,7 +1081,8 @@ InternalKVAccessor::InternalKVAccessor(GcsClient *client_impl)
     : client_impl_(client_impl) {}
 
 Status InternalKVAccessor::AsyncInternalKVGet(
-    const std::string &ns, const std::string &key,
+    const std::string &ns,
+    const std::string &key,
     const OptionalItemCallback<std::string> &callback) {
   rpc::InternalKVGetRequest req;
   req.set_key(key);
@@ -1090,7 +1102,8 @@ Status InternalKVAccessor::AsyncInternalKVGet(
 }
 
 Status InternalKVAccessor::AsyncInternalKVPut(const std::string &ns,
                                               const std::string &key,
-                                              const std::string &value, bool overwrite,
+                                              const std::string &value,
+                                              bool overwrite,
                                               const OptionalItemCallback<int> &callback) {
   rpc::InternalKVPutRequest req;
   req.set_namespace_(ns);
@@ -1107,7 +1120,8 @@ Status InternalKVAccessor::AsyncInternalKVPut(const std::string &ns,
 }
 
 Status InternalKVAccessor::AsyncInternalKVExists(
-    const std::string &ns, const std::string &key,
+    const std::string &ns,
+    const std::string &key,
     const OptionalItemCallback<bool> &callback) {
   rpc::InternalKVExistsRequest req;
   req.set_namespace_(ns);
@@ -1122,7 +1136,8 @@ Status InternalKVAccessor::AsyncInternalKVExists(
 }
 
 Status InternalKVAccessor::AsyncInternalKVDel(const std::string &ns,
-                                              const std::string &key, bool del_by_prefix,
+                                              const std::string &key,
+                                              bool del_by_prefix,
                                               const StatusCallback &callback) {
   rpc::InternalKVDelRequest req;
   req.set_namespace_(ns);
@@ -1137,7 +1152,8 @@ Status InternalKVAccessor::AsyncInternalKVDel(const std::string &ns,
 }
 
 Status InternalKVAccessor::AsyncInternalKVKeys(
-    const std::string &ns, const std::string &prefix,
+    const std::string &ns,
+    const std::string &prefix,
     const OptionalItemCallback<std::vector<std::string>> &callback) {
   rpc::InternalKVKeysRequest req;
   req.set_namespace_(ns);
@@ -1155,11 +1171,17 @@ Status InternalKVAccessor::AsyncInternalKVKeys(
   return Status::OK();
 }
 
-Status InternalKVAccessor::Put(const std::string &ns, const std::string &key,
-                               const std::string &value, bool overwrite, bool &added) {
+Status InternalKVAccessor::Put(const std::string &ns,
+                               const std::string &key,
+                               const std::string &value,
+                               bool overwrite,
+                               bool &added) {
   std::promise<Status> ret_promise;
   RAY_CHECK_OK(AsyncInternalKVPut(
-      ns, key, value, overwrite,
+      ns,
+      key,
+      value,
+      overwrite,
       [&ret_promise, &added](Status status, boost::optional<int> added_num) {
         added = static_cast<bool>(added_num.value_or(0));
         ret_promise.set_value(status);
@@ -1167,18 +1189,20 @@ Status InternalKVAccessor::Put(const std::string &ns, const std::string &key,
   return ret_promise.get_future().get();
 }
 
-Status InternalKVAccessor::Keys(const std::string &ns, const std::string &prefix,
+Status InternalKVAccessor::Keys(const std::string &ns,
+                                const std::string &prefix,
                                 std::vector<std::string> &value) {
   std::promise<Status> ret_promise;
-  RAY_CHECK_OK(AsyncInternalKVKeys(ns, prefix,
-                                   [&ret_promise, &value](Status status, auto &values) {
-                                     value = values.value_or(std::vector<std::string>());
-                                     ret_promise.set_value(status);
-                                   }));
+  RAY_CHECK_OK(AsyncInternalKVKeys(
+      ns, prefix, [&ret_promise, &value](Status status, auto &values) {
+        value = values.value_or(std::vector<std::string>());
+        ret_promise.set_value(status);
+      }));
   return ret_promise.get_future().get();
 }
 
-Status InternalKVAccessor::Get(const std::string &ns, const std::string &key,
+Status InternalKVAccessor::Get(const std::string &ns,
+                               const std::string &key,
                                std::string &value) {
   std::promise<Status> ret_promise;
   RAY_CHECK_OK(
@@ -1191,7 +1215,8 @@ Status InternalKVAccessor::Get(const std::string &ns, const std::string &key,
   return ret_promise.get_future().get();
 }
 
-Status InternalKVAccessor::Del(const std::string &ns, const std::string &key,
+Status InternalKVAccessor::Del(const std::string &ns,
+                               const std::string &key,
                                bool del_by_prefix) {
   std::promise<Status> ret_promise;
   RAY_CHECK_OK(AsyncInternalKVDel(ns, key, del_by_prefix, [&ret_promise](Status status) {
@@ -1200,7 +1225,8 @@ Status InternalKVAccessor::Del(const std::string &ns, const std::string &key,
   return ret_promise.get_future().get();
 }
 
-Status InternalKVAccessor::Exists(const std::string &ns, const std::string &key,
+Status InternalKVAccessor::Exists(const std::string &ns,
+                                  const std::string &key,
                                   bool &exist) {
   std::promise<Status> ret_promise;
   RAY_CHECK_OK(AsyncInternalKVExists(
diff --git a/src/ray/gcs/gcs_client/accessor.h b/src/ray/gcs/gcs_client/accessor.h
index 85f81536e..720b6c33d 100644
--- a/src/ray/gcs/gcs_client/accessor.h
+++ b/src/ray/gcs/gcs_client/accessor.h
@@ -65,7 +65,8 @@ class ActorInfoAccessor {
   /// \param callback Callback that will be called after lookup finishes.
   /// \param timeout_ms RPC timeout in milliseconds. -1 means the default.
   /// \return Status
-  virtual Status AsyncGetByName(const std::string &name, const std::string &ray_namespace,
+  virtual Status AsyncGetByName(const std::string &name,
+                                const std::string &ray_namespace,
                                 const OptionalItemCallback<rpc::ActorTableData> &callback,
                                 int64_t timeout_ms = -1);
 
@@ -77,7 +78,8 @@ class ActorInfoAccessor {
   /// \param ray_namespace The namespace to filter to.
   /// \return Status. TimedOut status if RPC is timed out.
   /// NotFound if the name doesn't exist.
-  virtual Status SyncGetByName(const std::string &name, const std::string &ray_namespace,
+  virtual Status SyncGetByName(const std::string &name,
+                               const std::string &ray_namespace,
                                rpc::ActorTableData &actor_table_data);
 
   /// List all named actors from the GCS asynchronously.
@@ -88,7 +90,8 @@ class ActorInfoAccessor {
   /// \param timeout_ms The RPC timeout in milliseconds. -1 means the default.
   /// \return Status
   virtual Status AsyncListNamedActors(
-      bool all_namespaces, const std::string &ray_namespace,
+      bool all_namespaces,
+      const std::string &ray_namespace,
      const OptionalItemCallback<std::vector<rpc::NamedActorInfo>> &callback,
      int64_t timeout_ms = -1);
 
@@ -101,7 +104,8 @@ class ActorInfoAccessor {
   /// \param[out] actors The pair of list of named actors. Each pair includes the
   /// namespace and name of the actor. \return Status. TimeOut if RPC times out.
   virtual Status SyncListNamedActors(
-      bool all_namespaces, const std::string &ray_namespace,
+      bool all_namespaces,
+      const std::string &ray_namespace,
       std::vector<std::pair<std::string, std::string>> &actors);
 
   /// Register actor to GCS asynchronously.
@@ -130,7 +134,9 @@ class ActorInfoAccessor {
   /// \param no_restart If set to true, the killed actor will not be restarted anymore.
   /// \param callback Callback that will be called after the actor is destroyed.
   /// \return Status
-  virtual Status AsyncKillActor(const ActorID &actor_id, bool force_kill, bool no_restart,
+  virtual Status AsyncKillActor(const ActorID &actor_id,
+                                bool force_kill,
+                                bool no_restart,
                                 const StatusCallback &callback);
 
   /// Asynchronously request GCS to create the actor.
@@ -676,7 +682,8 @@ class PlacementGroupInfoAccessor {
   /// \param timeout_ms The RPC timeout in milliseconds. -1 means the default.
   /// \return Status.
   virtual Status AsyncGetByName(
-      const std::string &placement_group_name, const std::string &ray_namespace,
+      const std::string &placement_group_name,
+      const std::string &ray_namespace,
       const OptionalItemCallback<rpc::PlacementGroupTableData> &callback,
       int64_t timeout_ms = -1);
 
@@ -720,7 +727,8 @@ class InternalKVAccessor {
   /// \param callback Callback that will be called after scanning.
   /// \return Status
   virtual Status AsyncInternalKVKeys(
-      const std::string &ns, const std::string &prefix,
+      const std::string &ns,
+      const std::string &prefix,
       const OptionalItemCallback<std::vector<std::string>> &callback);
 
   /// Asynchronously get the value for a given key.
@@ -728,7 +736,8 @@ class InternalKVAccessor {
   /// \param ns The namespace to lookup.
   /// \param key The key to lookup.
   /// \param callback Callback that will be called after get the value.
-  virtual Status AsyncInternalKVGet(const std::string &ns, const std::string &key,
+  virtual Status AsyncInternalKVGet(const std::string &ns,
+                                    const std::string &key,
                                     const OptionalItemCallback<std::string> &callback);
 
   /// Asynchronously set the value for a given key.
@@ -738,8 +747,10 @@ class InternalKVAccessor {
   /// \param value The value associated with the key
   /// \param callback Callback that will be called after the operation.
   /// \return Status
-  virtual Status AsyncInternalKVPut(const std::string &ns, const std::string &key,
-                                    const std::string &value, bool overwrite,
+  virtual Status AsyncInternalKVPut(const std::string &ns,
+                                    const std::string &key,
+                                    const std::string &value,
+                                    bool overwrite,
                                     const OptionalItemCallback<int> &callback);
 
   /// Asynchronously check the existence of a given key
@@ -748,7 +759,8 @@ class InternalKVAccessor {
   /// \param key The key to check.
   /// \param callback Callback that will be called after the operation.
   /// \return Status
-  virtual Status AsyncInternalKVExists(const std::string &ns, const std::string &key,
+  virtual Status AsyncInternalKVExists(const std::string &ns,
+                                       const std::string &key,
                                        const OptionalItemCallback<bool> &callback);
 
   /// Asynchronously delete a key
@@ -758,8 +770,10 @@ class InternalKVAccessor {
   /// \param del_by_prefix If set to be true, delete all keys with prefix as `key`.
   /// \param callback Callback that will be called after the operation.
   /// \return Status
-  virtual Status AsyncInternalKVDel(const std::string &ns, const std::string &key,
-                                    bool del_by_prefix, const StatusCallback &callback);
+  virtual Status AsyncInternalKVDel(const std::string &ns,
+                                    const std::string &key,
+                                    bool del_by_prefix,
+                                    const StatusCallback &callback);
 
   // These are sync functions of the async above
 
@@ -771,7 +785,8 @@ class InternalKVAccessor {
   /// \param prefix The prefix to scan.
   /// \param value It's an output parameter. It'll be set to the keys with `prefix`
   /// \return Status
-  virtual Status Keys(const std::string &ns, const std::string &prefix,
+  virtual Status Keys(const std::string &ns,
+                      const std::string &prefix,
                       std::vector<std::string> &value);
 
   /// Set the <key, value> in the store
@@ -786,8 +801,11 @@ class InternalKVAccessor {
   /// \param added It's an output parameter. It'll be set to be true if
   /// any row is added.
   /// \return Status
-  virtual Status Put(const std::string &ns, const std::string &key,
-                     const std::string &value, bool overwrite, bool &added);
+  virtual Status Put(const std::string &ns,
+                     const std::string &key,
+                     const std::string &value,
+                     bool overwrite,
+                     bool &added);
 
   /// Retrive the value associated with a key
   ///
diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc
index 5f52bd054..9d64157f9 100644
--- a/src/ray/gcs/gcs_client/gcs_client.cc
+++ b/src/ray/gcs/gcs_client/gcs_client.cc
@@ -69,8 +69,9 @@ void GcsSubscriberClient::PubsubCommandBatch(
   req.set_subscriber_id(request.subscriber_id());
   *req.mutable_commands() = request.commands();
   rpc_client_->GcsSubscriberCommandBatch(
-      req, [callback](const Status &status,
-                      const rpc::GcsSubscriberCommandBatchReply &batch_reply) {
+      req,
+      [callback](const Status &status,
+                 const rpc::GcsSubscriberCommandBatchReply &batch_reply) {
         rpc::PubsubCommandBatchReply reply;
         callback(status, reply);
       });
@@ -97,10 +98,13 @@ Status GcsClient::Connect(instrumented_io_context &io_service) {
     // Connect to redis.
     // We don't access redis shardings in GCS client, so we set `enable_sharding_conn`
     // to false.
-    RedisClientOptions redis_client_options(
-        options_.redis_ip_, options_.redis_port_, options_.password_,
-        /*enable_sharding_conn=*/false, options_.enable_sync_conn_,
-        options_.enable_async_conn_, options_.enable_subscribe_conn_);
+    RedisClientOptions redis_client_options(options_.redis_ip_,
+                                            options_.redis_port_,
+                                            options_.password_,
+                                            /*enable_sharding_conn=*/false,
+                                            options_.enable_sync_conn_,
+                                            options_.enable_async_conn_,
+                                            options_.enable_subscribe_conn_);
     redis_client_ = std::make_shared<RedisClient>(redis_client_options);
     RAY_CHECK_OK(redis_client_->Connect(io_service));
   } else {
@@ -144,7 +148,8 @@ Status GcsClient::Connect(instrumented_io_context &io_service) {
   // Connect to gcs service.
   client_call_manager_ = std::make_unique<rpc::ClientCallManager>(io_service);
   gcs_rpc_client_ = std::make_shared<rpc::GcsRpcClient>(
-      current_gcs_server_address_.first, current_gcs_server_address_.second,
+      current_gcs_server_address_.first,
+      current_gcs_server_address_.second,
       *client_call_manager_,
       [this](rpc::GcsServiceFailureType type) { GcsServiceFailureDetected(type); });
 
diff --git a/src/ray/gcs/gcs_client/gcs_client.h b/src/ray/gcs/gcs_client/gcs_client.h
index a06c3ab59..c27b167b1 100644
--- a/src/ray/gcs/gcs_client/gcs_client.h
+++ b/src/ray/gcs/gcs_client/gcs_client.h
@@ -45,9 +45,12 @@ class GcsClientOptions {
   /// \param ip redis service ip.
   /// \param port redis service port.
   /// \param password redis service password.
-  GcsClientOptions(const std::string &redis_ip, int redis_port,
-                   const std::string &password, bool enable_sync_conn = true,
-                   bool enable_async_conn = true, bool enable_subscribe_conn = true)
+  GcsClientOptions(const std::string &redis_ip,
+                   int redis_port,
+                   const std::string &password,
+                   bool enable_sync_conn = true,
+                   bool enable_async_conn = true,
+                   bool enable_subscribe_conn = true)
       : redis_ip_(redis_ip),
         redis_port_(redis_port),
         password_(password),
diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc
index 4a252d2e1..df943e9d0 100644
--- a/src/ray/gcs/gcs_client/global_state_accessor.cc
+++ b/src/ray/gcs/gcs_client/global_state_accessor.cc
@@ -175,8 +175,9 @@ std::unique_ptr<std::string> GlobalStateAccessor::GetActorInfo(const ActorID &ac
   {
     absl::ReaderMutexLock lock(&mutex_);
     RAY_CHECK_OK(gcs_client_->Actors().AsyncGet(
-        actor_id, TransformForOptionalItemCallback<rpc::ActorTableData>(actor_table_data,
-                                                                        promise)));
+        actor_id,
+        TransformForOptionalItemCallback<rpc::ActorTableData>(actor_table_data,
+                                                              promise)));
   }
   promise.get_future().get();
   return actor_table_data;
@@ -189,8 +190,9 @@ std::unique_ptr<std::string> GlobalStateAccessor::GetWorkerInfo(
   {
     absl::ReaderMutexLock lock(&mutex_);
     RAY_CHECK_OK(gcs_client_->Workers().AsyncGet(
-        worker_id, TransformForOptionalItemCallback<rpc::WorkerTableData>(
-                       worker_table_data, promise)));
+        worker_id,
+        TransformForOptionalItemCallback<rpc::WorkerTableData>(worker_table_data,
+                                                               promise)));
   }
   promise.get_future().get();
   return worker_table_data;
@@ -259,7 +261,8 @@ std::unique_ptr<std::string> GlobalStateAccessor::GetPlacementGroupByName(
   {
     absl::ReaderMutexLock lock(&mutex_);
     RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncGetByName(
-        placement_group_name, ray_namespace,
+        placement_group_name,
+        ray_namespace,
         TransformForOptionalItemCallback<rpc::PlacementGroupTableData>(
             placement_group_table_data, promise)));
   }
@@ -317,7 +320,9 @@ ray::Status GlobalStateAccessor::GetNodeToConnectForDriver(
 
   // Deal with alive nodes only
   std::vector<rpc::GcsNodeInfo> nodes;
-  std::copy_if(result.second.begin(), result.second.end(), std::back_inserter(nodes),
+  std::copy_if(result.second.begin(),
+               result.second.end(),
+               std::back_inserter(nodes),
                [](const rpc::GcsNodeInfo &node) {
                  return node.state() == rpc::GcsNodeInfo::ALIVE;
                });
diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h
index 80b4d4cb6..9f948cc76 100644
--- a/src/ray/gcs/gcs_client/global_state_accessor.h
+++ b/src/ray/gcs/gcs_client/global_state_accessor.h
@@ -190,7 +190,9 @@ class GlobalStateAccessor {
       std::vector<std::string> &data_vec, std::promise<bool> &promise) {
     return [&data_vec, &promise](const Status &status, std::vector<DATA> &&result) {
       RAY_CHECK_OK(status);
-      std::transform(result.begin(), result.end(), std::back_inserter(data_vec),
+      std::transform(result.begin(),
+                     result.end(),
+                     std::back_inserter(data_vec),
                      [](const DATA &data) { return data.SerializeAsString(); });
       promise.set_value(true);
     };
diff --git a/src/ray/gcs/gcs_client/test/gcs_client_test.cc b/src/ray/gcs/gcs_client/test/gcs_client_test.cc
index 4529262b9..994480cc5 100644
--- a/src/ray/gcs/gcs_client/test/gcs_client_test.cc
+++ b/src/ray/gcs/gcs_client/test/gcs_client_test.cc
@@ -100,8 +100,8 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
       gcs::GcsClientOptions options("127.0.0.1:5397");
       gcs_client_ = std::make_unique<gcs::GcsClient>(options);
     } else {
-      gcs::GcsClientOptions options(config_.redis_address, config_.redis_port,
-                                    config_.redis_password);
+      gcs::GcsClientOptions options(
+          config_.redis_address, config_.redis_port, config_.redis_password);
       gcs_client_ = std::make_unique<gcs::GcsClient>(options);
     }
     RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
@@ -185,8 +185,9 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
                       const gcs::SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe) {
     std::promise<bool> promise;
     RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe(
-        actor_id, subscribe,
-        [&promise](Status status) { promise.set_value(status.ok()); }));
+        actor_id, subscribe, [&promise](Status status) {
+          promise.set_value(status.ok());
+        }));
     return WaitReady(promise.get_future(), timeout_ms_);
   }
 
@@ -202,7 +203,8 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
   }
 
   bool RegisterActor(const std::shared_ptr<rpc::ActorTableData> &actor_table_data,
-                     bool is_detached = true, bool skip_wait = false) {
+                     bool is_detached = true,
+                     bool skip_wait = false) {
     rpc::TaskSpec message;
     auto actor_id = ActorID::FromBinary(actor_table_data->actor_id());
     message.set_job_id(actor_id.JobId().Binary());
@@ -247,8 +249,9 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
     std::promise<bool> promise;
     rpc::ActorTableData actor_table_data;
     RAY_CHECK_OK(gcs_client_->Actors().AsyncGet(
-        actor_id, [&actor_table_data, &promise](
-                      Status status, const boost::optional<rpc::ActorTableData> &result) {
+        actor_id,
+        [&actor_table_data, &promise](
+            Status status, const boost::optional<rpc::ActorTableData> &result) {
          assert(result);
          actor_table_data.CopyFrom(*result);
          promise.set_value(true);
@@ -998,7 +1001,8 @@ TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
 
 int main(int argc, char **argv) {
   InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
-                                         ray::RayLog::ShutDownRayLog, argv[0],
+                                         ray::RayLog::ShutDownRayLog,
+                                         argv[0],
                                          ray::RayLogLevel::INFO,
                                          /*log_dir=*/"");
   ::testing::InitGoogleTest(&argc, argv);
diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc
index b0fe4efa6..f0d4b3385 100644
--- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc
+++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc
@@ -82,8 +82,8 @@ class GlobalStateAccessorTest : public ::testing::TestWithParam<bool> {
       gcs_client_ = std::make_unique<gcs::GcsClient>(options);
       global_state_ = std::make_unique<gcs::GlobalStateAccessor>(options);
     } else {
-      gcs::GcsClientOptions options(config.redis_address, config.redis_port,
-                                    config.redis_password);
+      gcs::GcsClientOptions options(
+          config.redis_address, config.redis_port, config.redis_password);
       gcs_client_ = std::make_unique<gcs::GcsClient>(options);
       global_state_ = std::make_unique<gcs::GlobalStateAccessor>(options);
     }
@@ -272,7 +272,8 @@ TEST_P(GlobalStateAccessorTest, TestPlacementGroupTable) {
   ASSERT_EQ(global_state_->GetAllPlacementGroupInfo().size(), 0);
 }
 
-INSTANTIATE_TEST_SUITE_P(RedisRemovalTest, GlobalStateAccessorTest,
+INSTANTIATE_TEST_SUITE_P(RedisRemovalTest,
+                         GlobalStateAccessorTest,
                          ::testing::Values(false, true));
 
 }  // namespace ray
@@ -280,7 +281,8 @@ INSTANTIATE_TEST_SUITE_P(RedisRemovalTest, GlobalStateAccessorTest,
 int main(int argc, char **argv) {
   ray::RayLog::InstallFailureSignalHandler(argv[0]);
   InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
-                                         ray::RayLog::ShutDownRayLog, argv[0],
+                                         ray::RayLog::ShutDownRayLog,
+                                         argv[0],
                                          ray::RayLogLevel::INFO,
                                          /*log_dir=*/"");
   ::testing::InitGoogleTest(&argc, argv);
diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc
index 747073ee4..eb2d31af1 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_distribution.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.cc
@@ -33,7 +33,8 @@ const ResourceRequest &GcsActorWorkerAssignment::GetResources() const {
 
 bool GcsActorWorkerAssignment::IsShared() const { return is_shared_; }
 
 GcsBasedActorScheduler::GcsBasedActorScheduler(
-    instrumented_io_context &io_context, GcsActorTable &gcs_actor_table,
+    instrumented_io_context &io_context,
+    GcsActorTable &gcs_actor_table,
     const GcsNodeManager &gcs_node_manager,
     std::shared_ptr<GcsResourceManager> gcs_resource_manager,
     std::shared_ptr<GcsResourceScheduler> gcs_resource_scheduler,
@@ -41,9 +42,13 @@ GcsBasedActorScheduler::GcsBasedActorScheduler(
     GcsActorSchedulerSuccessCallback schedule_success_handler,
     std::shared_ptr<rpc::NodeManagerClientPool> raylet_client_pool,
     rpc::ClientFactoryFn client_factory)
-    : GcsActorScheduler(io_context, gcs_actor_table, gcs_node_manager,
-                        schedule_failure_handler, schedule_success_handler,
-                        raylet_client_pool, client_factory),
+    : GcsActorScheduler(io_context,
+                        gcs_actor_table,
+                        gcs_node_manager,
+                        schedule_failure_handler,
+                        schedule_success_handler,
+                        raylet_client_pool,
+                        client_factory),
       gcs_resource_manager_(std::move(gcs_resource_manager)),
       gcs_resource_scheduler_(std::move(gcs_resource_scheduler)) {}
 
@@ -71,15 +76,16 @@ GcsBasedActorScheduler::SelectOrAllocateActorWorkerAssignment(
       /*requires_object_store_memory=*/false);
 
   // If the task needs a sole actor worker assignment then allocate a new one.
-  return AllocateNewActorWorkerAssignment(required_resources, /*is_shared=*/false,
-                                          task_spec);
+  return AllocateNewActorWorkerAssignment(
+      required_resources, /*is_shared=*/false, task_spec);
 
   // TODO(Chong-Li): code path for actors that do not need a sole assignment.
 }
 
 std::unique_ptr<GcsActorWorkerAssignment>
 GcsBasedActorScheduler::AllocateNewActorWorkerAssignment(
-    const ResourceRequest &required_resources, bool is_shared,
+    const ResourceRequest &required_resources,
+    bool is_shared,
     const TaskSpecification &task_spec) {
   // Allocate resources from cluster.
   auto selected_node_id = AllocateResources(required_resources);
@@ -163,8 +169,10 @@ void GcsBasedActorScheduler::WarnResourceAllocationFailure(
 }
 
 void GcsBasedActorScheduler::HandleWorkerLeaseReply(
-    std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node,
-    const Status &status, const rpc::RequestWorkerLeaseReply &reply) {
+    std::shared_ptr<GcsActor> actor,
+    std::shared_ptr<rpc::GcsNodeInfo> node,
+    const Status &status,
+    const rpc::RequestWorkerLeaseReply &reply) {
   auto node_id = NodeID::FromBinary(node->node_id());
   // If the actor is still in the leasing map and the status is ok, remove the actor
   // from the leasing map and handle the reply. Otherwise, lease again, because it
@@ -195,7 +203,9 @@ void GcsBasedActorScheduler::HandleWorkerLeaseReply(
   }
   if (reply.canceled()) {
     // TODO(sang): Should properly update the failure message.
-    HandleRequestWorkerLeaseCanceled(actor, node_id, reply.failure_type(),
+    HandleRequestWorkerLeaseCanceled(actor,
+                                     node_id,
+                                     reply.failure_type(),
                                      /*scheduling_failure_message*/ "");
   } else if (reply.rejected()) {
     RAY_LOG(INFO) << "Failed to lease worker from node " << node_id << " for actor "
diff --git a/src/ray/gcs/gcs_server/gcs_actor_distribution.h b/src/ray/gcs/gcs_server/gcs_actor_distribution.h
index f1f298048..279f72709 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_distribution.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_distribution.h
@@ -42,7 +42,8 @@ class GcsActorWorkerAssignment
   /// \param acquired_resources Resources owned by this gcs actor worker assignment.
   /// \param is_shared A flag to represent that whether the worker process can be shared.
   GcsActorWorkerAssignment(const NodeID &node_id,
-                           const ResourceRequest &acquired_resources, bool is_shared);
+                           const ResourceRequest &acquired_resources,
+                           bool is_shared);
 
   const NodeID &GetNodeID() const;
 
@@ -80,7 +81,8 @@ class GcsBasedActorScheduler : public GcsActorScheduler {
   /// \param client_factory Factory to create remote core worker client, default factor
   /// will be used if not set.
   explicit GcsBasedActorScheduler(
-      instrumented_io_context &io_context, GcsActorTable &gcs_actor_table,
+      instrumented_io_context &io_context,
+      GcsActorTable &gcs_actor_table,
       const GcsNodeManager &gcs_node_manager,
       std::shared_ptr<GcsResourceManager> gcs_resource_manager,
       std::shared_ptr<GcsResourceScheduler> gcs_resource_scheduler,
@@ -130,7 +132,8 @@ class GcsBasedActorScheduler : public GcsActorScheduler {
   /// \param is_shared If the worker is shared by multiple actors or not.
   /// \param task_spec The specification of the task.
   std::unique_ptr<GcsActorWorkerAssignment> AllocateNewActorWorkerAssignment(
-      const ResourceRequest &required_resources, bool is_shared,
+      const ResourceRequest &required_resources,
+      bool is_shared,
       const TaskSpecification &task_spec);
 
   /// Allocate resources for the actor.
diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc
index eb2a989cb..4b2c7c257 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc
@@ -59,7 +59,8 @@ const ray::rpc::ActorDeathCause GenNodeDiedCause(const ray::gcs::GcsActor *actor
 }
 
 const ray::rpc::ActorDeathCause GenWorkerDiedCause(
-    const ray::gcs::GcsActor *actor, const std::string &ip_address,
+    const ray::gcs::GcsActor *actor,
+    const std::string &ip_address,
     const ray::rpc::WorkerExitType &disconnect_type) {
   ray::rpc::ActorDeathCause death_cause;
   auto actor_died_error_ctx = death_cause.mutable_actor_died_error_context();
@@ -70,15 +71,20 @@ const ray::rpc::ActorDeathCause GenWorkerDiedCause(
   return death_cause;
 }
 const ray::rpc::ActorDeathCause GenOwnerDiedCause(
-    const ray::gcs::GcsActor *actor, const WorkerID &owner_id,
-    const ray::rpc::WorkerExitType disconnect_type, const std::string owner_ip_address) {
+    const ray::gcs::GcsActor *actor,
+    const WorkerID &owner_id,
+    const ray::rpc::WorkerExitType disconnect_type,
+    const std::string owner_ip_address) {
   ray::rpc::ActorDeathCause death_cause;
   auto actor_died_error_ctx = death_cause.mutable_actor_died_error_context();
   AddActorInfo(actor, actor_died_error_ctx);
-  actor_died_error_ctx->set_error_message(absl::StrCat(
-      "The actor is dead because its owner has died. Owner Id: ", owner_id.Hex(),
-      " Owner Ip address: ", owner_ip_address,
-      " Owner worker exit type: ", ray::rpc::WorkerExitType_Name(disconnect_type)));
+  actor_died_error_ctx->set_error_message(
+      absl::StrCat("The actor is dead because its owner has died. Owner Id: ",
+                   owner_id.Hex(),
+                   " Owner Ip address: ",
+                   owner_ip_address,
+                   " Owner worker exit type: ",
+                   ray::rpc::WorkerExitType_Name(disconnect_type)));
   return death_cause;
 }
 
@@ -191,7 +197,8 @@ GcsActorManager::GcsActorManager(
     boost::asio::io_context &io_context,
     std::shared_ptr<GcsActorSchedulerInterface> scheduler,
     std::shared_ptr<GcsTableStorage> gcs_table_storage,
-    std::shared_ptr<GcsPublisher> gcs_publisher, RuntimeEnvManager &runtime_env_manager,
+    std::shared_ptr<GcsPublisher> gcs_publisher,
+    RuntimeEnvManager &runtime_env_manager,
     GcsFunctionManager &function_manager,
     std::function<void(const ActorID &)> destroy_owned_placement_group_if_needed,
     std::function<std::shared_ptr<rpc::JobConfig>(const JobID &)> get_job_config,
@@ -234,12 +241,13 @@ void GcsActorManager::HandleRegisterActor(const rpc::RegisterActorRequest &reque
   RAY_LOG(INFO) << "Registering actor, job id = " << actor_id.JobId()
                 << ", actor id = " << actor_id;
   Status status =
-      RegisterActor(request, [reply, send_reply_callback,
-                              actor_id](const std::shared_ptr<GcsActor> &actor) {
-        RAY_LOG(INFO) << "Registered actor, job id = " << actor_id.JobId()
-                      << ", actor id = " << actor_id;
-        GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
-      });
+      RegisterActor(request,
+                    [reply, send_reply_callback, actor_id](
+                        const std::shared_ptr<GcsActor> &actor) {
+                      RAY_LOG(INFO) << "Registered actor, job id = " << actor_id.JobId()
+                                    << ", actor id = " << actor_id;
+                      GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
+                    });
   if (!status.ok()) {
     RAY_LOG(WARNING) << "Failed to register actor: " << status.ToString()
                      << ", job id = " << actor_id.JobId() << ", actor id = " << actor_id;
@@ -257,15 +265,16 @@ void GcsActorManager::HandleCreateActor(const rpc::CreateActorRequest &request,
   RAY_LOG(INFO) << "Creating actor, job id = " << actor_id.JobId()
                 << ", actor id = " << actor_id;
 
-  Status status = CreateActor(request, [reply, send_reply_callback, actor_id](
-                                           const std::shared_ptr<GcsActor> &actor,
-                                           const rpc::PushTaskReply &task_reply) {
-    RAY_LOG(INFO) << "Finished creating actor, job id = " << actor_id.JobId()
-                  << ", actor id = " << actor_id;
-    reply->mutable_actor_address()->CopyFrom(actor->GetAddress());
-    reply->mutable_borrowed_refs()->CopyFrom(task_reply.borrowed_refs());
-    GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
-  });
+  Status status = CreateActor(
+      request,
+      [reply, send_reply_callback, actor_id](const std::shared_ptr<GcsActor> &actor,
+                                             const rpc::PushTaskReply &task_reply) {
+        RAY_LOG(INFO) << "Finished creating actor, job id = " << actor_id.JobId()
+                      << ", actor id = " << actor_id;
+        reply->mutable_actor_address()->CopyFrom(actor->GetAddress());
+        reply->mutable_borrowed_refs()->CopyFrom(task_reply.borrowed_refs());
+        GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
+      });
   if (!status.ok()) {
     RAY_LOG(WARNING) << "Failed to create actor, job id = " << actor_id.JobId()
                      << ", actor id = " << actor_id << ", status: " << status.ToString();
@@ -339,7 +348,8 @@ void GcsActorManager::HandleGetAllActorInfo(const rpc::GetAllActorInfoRequest &r
 }
 
 void GcsActorManager::HandleGetNamedActorInfo(
-    const rpc::GetNamedActorInfoRequest &request, rpc::GetNamedActorInfoReply *reply,
+    const rpc::GetNamedActorInfoRequest &request,
+    rpc::GetNamedActorInfoReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   const std::string &name = request.name();
   const std::string &ray_namespace = request.ray_namespace();
@@ -452,8 +462,10 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ
            << actor->GetRayNamespace() << "\", ...)";
 
     auto error_data_ptr =
-        gcs::CreateErrorTableData("detached_actor_anonymous_namespace", stream.str(),
-                                  absl::GetCurrentTimeNanos(), job_id);
+        gcs::CreateErrorTableData("detached_actor_anonymous_namespace",
+                                  stream.str(),
+                                  absl::GetCurrentTimeNanos(),
+                                  job_id);
     RAY_LOG(WARNING) << error_data_ptr->SerializeAsString();
 
     RAY_CHECK_OK(
@@ -489,7 +501,8 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ
 
   // The backend storage is supposed to be reliable, so the status must be ok.
   RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put(
-      actor->GetActorID(), *actor->GetMutableActorTableData(),
+      actor->GetActorID(),
+      *actor->GetMutableActorTableData(),
       [this, actor](const Status &status) {
         // The backend storage is supposed to be reliable, so the status must be ok.
         RAY_CHECK_OK(status);
@@ -504,8 +517,8 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ
           // the actor state to DEAD to avoid race condition.
          return;
        }
-        RAY_CHECK_OK(gcs_publisher_->PublishActor(actor->GetActorID(),
-                                                  actor->GetActorTableData(), nullptr));
+        RAY_CHECK_OK(gcs_publisher_->PublishActor(
+            actor->GetActorID(), actor->GetActorTableData(), nullptr));
        // Invoke all callbacks for all registration requests of this actor (duplicated
        // requests are included) and remove all of them from
        // actor_to_register_callbacks_.
@@ -667,8 +680,9 @@ void GcsActorManager::PollOwnerForActorOutOfScope(
   wait_request.set_intended_worker_id(owner_id.Binary());
   wait_request.set_actor_id(actor_id.Binary());
   it->second.client->WaitForActorOutOfScope(
-      wait_request, [this, owner_node_id, owner_id, actor_id](
-                        Status status, const rpc::WaitForActorOutOfScopeReply &reply) {
+      wait_request,
+      [this, owner_node_id, owner_id, actor_id](
+          Status status, const rpc::WaitForActorOutOfScopeReply &reply) {
         if (!status.ok()) {
           RAY_LOG(INFO) << "Worker " << owner_id
                         << " failed, destroying actor child, job id = "
@@ -771,7 +785,8 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id,
       std::make_shared<rpc::ActorTableData>(*mutable_actor_table_data);
   // The backend storage is reliable in the future, so the status must be ok.
   RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put(
-      actor->GetActorID(), *actor_table_data,
+      actor->GetActorID(),
+      *actor_table_data,
       [this, actor_id, actor_table_data](Status status) {
         RAY_CHECK_OK(gcs_publisher_->PublishActor(
             actor_id, *GenActorDataOnlyWithStates(*actor_table_data), nullptr));
@@ -817,10 +832,14 @@ void GcsActorManager::OnWorkerDead(const ray::NodeID &node_id,
                                    const std::string &worker_ip,
                                    const rpc::WorkerExitType disconnect_type,
                                    const rpc::RayException *creation_task_exception) {
-  std::string message = absl::StrCat(
-      "Worker ", worker_id.Hex(), " on node ", node_id.Hex(),
-      " exits, type=", rpc::WorkerExitType_Name(disconnect_type),
-      ", has creation_task_exception = ", (creation_task_exception != nullptr));
+  std::string message = absl::StrCat("Worker ",
+                                     worker_id.Hex(),
+                                     " on node ",
+                                     node_id.Hex(),
+                                     " exits, type=",
+                                     rpc::WorkerExitType_Name(disconnect_type),
+                                     ", has creation_task_exception = ",
+                                     (creation_task_exception != nullptr));
   if (disconnect_type == rpc::WorkerExitType::INTENDED_EXIT ||
       disconnect_type == rpc::WorkerExitType::IDLE_EXIT) {
     RAY_LOG(DEBUG) << message;
@@ -838,8 +857,9 @@ void GcsActorManager::OnWorkerDead(const ray::NodeID &node_id,
     // list.
   const auto children_ids = owner->second.children_actor_ids;
   for (const auto &child_id : children_ids) {
-    DestroyActor(child_id, GenOwnerDiedCause(GetActor(child_id), worker_id,
-                                             disconnect_type, worker_ip));
+    DestroyActor(
+        child_id,
+        GenOwnerDiedCause(GetActor(child_id), worker_id, disconnect_type, worker_ip));
   }
 }
 
@@ -849,8 +869,9 @@ void GcsActorManager::OnWorkerDead(const ray::NodeID &node_id,
   auto unresolved_actors = GetUnresolvedActorsByOwnerWorker(node_id, worker_id);
   for (auto &actor_id : unresolved_actors) {
     if (registered_actors_.count(actor_id)) {
-      DestroyActor(actor_id, GenOwnerDiedCause(GetActor(actor_id), worker_id,
-                                               disconnect_type, worker_ip));
+      DestroyActor(
+          actor_id,
+          GenOwnerDiedCause(GetActor(actor_id), worker_id, disconnect_type, worker_ip));
     }
   }
 
@@ -873,8 +894,8 @@ void GcsActorManager::OnWorkerDead(const ray::NodeID &node_id,
   rpc::ActorDeathCause death_cause;
   if (creation_task_exception != nullptr) {
-    absl::StrAppend(&message, ": ",
-                    creation_task_exception->formatted_exception_string());
+    absl::StrAppend(
+        &message, ": ", creation_task_exception->formatted_exception_string());
 
     death_cause.mutable_creation_task_failure_context()->CopyFrom(
         *creation_task_exception);
@@ -901,15 +922,18 @@ void GcsActorManager::OnNodeDead(const NodeID &node_id,
     }
     for (const auto &[owner_id, child_id] : children_ids) {
       DestroyActor(child_id,
-                   GenOwnerDiedCause(GetActor(child_id), owner_id,
-                                     rpc::WorkerExitType::NODE_DIED, node_ip_address));
+                   GenOwnerDiedCause(GetActor(child_id),
+                                     owner_id,
+                                     rpc::WorkerExitType::NODE_DIED,
+                                     node_ip_address));
     }
   }
 
   // Cancel scheduling actors that haven't been created on the node.
   auto scheduling_actor_ids = gcs_actor_scheduler_->CancelOnNode(node_id);
   for (auto &actor_id : scheduling_actor_ids) {
-    ReconstructActor(actor_id, /*need_reschedule=*/true,
+    ReconstructActor(actor_id,
+                     /*need_reschedule=*/true,
                      GenNodeDiedCause(GetActor(actor_id), node_ip_address, node_id));
   }
 
@@ -922,7 +946,8 @@ void GcsActorManager::OnNodeDead(const NodeID &node_id,
   for (auto &entry : created_actors) {
     // Reconstruct the removed actor.
     ReconstructActor(
-        entry.second, /*need_reschedule=*/true,
+        entry.second,
+        /*need_reschedule=*/true,
         GenNodeDiedCause(GetActor(entry.second), node_ip_address, node_id));
   }
 
@@ -935,14 +960,17 @@ void GcsActorManager::OnNodeDead(const NodeID &node_id,
     for (const auto &actor_id : actor_ids) {
       if (registered_actors_.count(actor_id)) {
        DestroyActor(actor_id,
-                     GenOwnerDiedCause(GetActor(actor_id), owner_id,
-                                       rpc::WorkerExitType::NODE_DIED, node_ip_address));
+                     GenOwnerDiedCause(GetActor(actor_id),
+                                       owner_id,
+                                       rpc::WorkerExitType::NODE_DIED,
+                                       node_ip_address));
      }
    }
  }
 }
 
-void GcsActorManager::ReconstructActor(const ActorID &actor_id, bool need_reschedule,
+void GcsActorManager::ReconstructActor(const ActorID &actor_id,
+                                       bool need_reschedule,
                                        const rpc::ActorDeathCause &death_cause) {
   // If the owner and this actor is dead at the same time, the actor
   // could've been destroyed and dereigstered before reconstruction.
@@ -989,7 +1017,8 @@ void GcsActorManager::ReconstructActor(const ActorID &actor_id, bool need_resche
   mutable_actor_table_data->clear_resource_mapping();
   // The backend storage is reliable in the future, so the status must be ok.
   RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put(
-      actor_id, *mutable_actor_table_data,
+      actor_id,
+      *mutable_actor_table_data,
       [this, actor_id, mutable_actor_table_data](Status status) {
         RAY_CHECK_OK(gcs_publisher_->PublishActor(
             actor_id, *GenActorDataOnlyWithStates(*mutable_actor_table_data), nullptr));
@@ -1005,7 +1034,8 @@ void GcsActorManager::ReconstructActor(const ActorID &actor_id, bool need_resche
 
     // The backend storage is reliable in the future, so the status must be ok.
     RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put(
-        actor_id, *mutable_actor_table_data,
+        actor_id,
+        *mutable_actor_table_data,
         [this, actor, actor_id, mutable_actor_table_data, death_cause](Status status) {
           // If actor was an detached actor, make sure to destroy it.
          // We need to do this because detached actors are not destroyed
@@ -1092,7 +1122,8 @@ void GcsActorManager::OnActorCreationSuccess(const std::shared_ptr<GcsActor> &ac
   auto actor_table_data = *mutable_actor_table_data;
   // The backend storage is reliable in the future, so the status must be ok.
   RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put(
-      actor_id, actor_table_data,
+      actor_id,
+      actor_table_data,
       [this, actor_id, actor_table_data, actor, reply](Status status) {
         RAY_CHECK_OK(gcs_publisher_->PublishActor(
             actor_id, *GenActorDataOnlyWithStates(actor_table_data), nullptr));
@@ -1286,7 +1317,8 @@ void GcsActorManager::RemoveActorFromOwner(const std::shared_ptr<GcsActor> &acto
 }
 
 void GcsActorManager::NotifyCoreWorkerToKillActor(const std::shared_ptr<GcsActor> &actor,
-                                                  bool force_kill, bool no_restart) {
+                                                  bool force_kill,
+                                                  bool no_restart) {
   rpc::KillActorRequest request;
   request.set_intended_actor_id(actor->GetActorID().Binary());
   request.set_force_kill(force_kill);
@@ -1299,7 +1331,8 @@ void GcsActorManager::NotifyCoreWorkerToKillActor(const std::shared_ptr
                   << actor->GetActorID() << " hasn't been created yet, cancel scheduling "
                   << task_id;
     CancelActorInScheduling(actor, task_id);
-    ReconstructActor(actor_id, /*need_reschedule=*/true,
+    ReconstructActor(actor_id,
+                     /*need_reschedule=*/true,
                      GenKilledByApplicationCause(GetActor(actor_id)));
   }
 }
 
@@ -1362,7 +1396,8 @@ void GcsActorManager::CancelActorInScheduling(const std::shared_ptr<GcsActor> &a
     // The actor was being scheduled and has now been canceled.
     RAY_CHECK(canceled_actor_id == actor_id);
   } else {
-    auto pending_it = std::find_if(pending_actors_.begin(), pending_actors_.end(),
+    auto pending_it = std::find_if(pending_actors_.begin(),
+                                   pending_actors_.end(),
                                    [actor_id](const std::shared_ptr<GcsActor> &actor) {
                                      return actor->GetActorID() == actor_id;
                                    });
diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h
index 60962e98b..788241df7 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h
@@ -198,7 +198,8 @@ class GcsActorManager : public rpc::ActorInfoHandler {
       boost::asio::io_context &io_context,
       std::shared_ptr<GcsActorSchedulerInterface> scheduler,
       std::shared_ptr<GcsTableStorage> gcs_table_storage,
-      std::shared_ptr<GcsPublisher> gcs_publisher, RuntimeEnvManager &runtime_env_manager,
+      std::shared_ptr<GcsPublisher> gcs_publisher,
+      RuntimeEnvManager &runtime_env_manager,
       GcsFunctionManager &function_manager,
       std::function<void(const ActorID &)> destroy_ownded_placement_group_if_needed,
       std::function<std::shared_ptr<rpc::JobConfig>(const JobID &)> get_job_config,
@@ -302,7 +303,8 @@ class GcsActorManager : public rpc::ActorInfoHandler {
   /// \param exit_type exit reason of the dead worker.
   /// \param creation_task_exception if this arg is set, this worker is died because of an
   /// exception thrown in actor's creation task.
-  void OnWorkerDead(const NodeID &node_id, const WorkerID &worker_id,
+  void OnWorkerDead(const NodeID &node_id,
+                    const WorkerID &worker_id,
                     const std::string &worker_ip,
                     const rpc::WorkerExitType disconnect_type,
                     const rpc::RayException *creation_task_exception = nullptr);
 
@@ -389,7 +391,8 @@ class GcsActorManager : public rpc::ActorInfoHandler {
   /// \param[in] actor_id The actor id to destroy.
   /// \param[in] death_cause The reason why actor is destroyed.
   /// \param[in] force_kill Whether destory the actor forcelly.
-  void DestroyActor(const ActorID &actor_id, const rpc::ActorDeathCause &death_cause,
+  void DestroyActor(const ActorID &actor_id,
+                    const rpc::ActorDeathCause &death_cause,
                     bool force_kill = true);
 
   /// Get unresolved actors that were submitted from the specified node.
@@ -408,7 +411,8 @@ class GcsActorManager : public rpc::ActorInfoHandler {
   /// again.
   /// \param death_cause Context about why this actor is dead. Should only be set when
   /// need_reschedule=false.
-  void ReconstructActor(const ActorID &actor_id, bool need_reschedule,
+  void ReconstructActor(const ActorID &actor_id,
+                        bool need_reschedule,
                         const rpc::ActorDeathCause &death_cause);
 
   /// Remove the specified actor from `unresolved_actors_`.
@@ -434,7 +438,8 @@ class GcsActorManager : public rpc::ActorInfoHandler {
   /// \param force_kill Whether to force kill an actor by killing the worker.
   /// \param no_restart If set to true, the killed actor will not be restarted anymore.
   void NotifyCoreWorkerToKillActor(const std::shared_ptr<GcsActor> &actor,
-                                   bool force_kill = true, bool no_restart = true);
+                                   bool force_kill = true,
+                                   bool no_restart = true);
 
   /// Add the destroyed actor to the cache. If the cache is full, one actor is randomly
   /// evicted.
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
index c66656edb..24a3aabd6 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc
@@ -24,7 +24,8 @@ namespace ray {
 namespace gcs {
 
 GcsActorScheduler::GcsActorScheduler(
-    instrumented_io_context &io_context, GcsActorTable &gcs_actor_table,
+    instrumented_io_context &io_context,
+    GcsActorTable &gcs_actor_table,
     const GcsNodeManager &gcs_node_manager,
     GcsActorSchedulerFailureCallback schedule_failure_handler,
     GcsActorSchedulerSuccessCallback schedule_success_handler,
@@ -126,7 +127,8 @@ std::vector<ActorID> GcsActorScheduler::CancelOnNode(const NodeID &node_id) {
   return actor_ids;
 }
 
-void GcsActorScheduler::CancelOnLeasing(const NodeID &node_id, const ActorID &actor_id,
+void GcsActorScheduler::CancelOnLeasing(const NodeID &node_id,
+                                        const ActorID &actor_id,
                                         const TaskID &task_id) {
   // NOTE: This method will cancel the outstanding lease request and remove leasing
   // information from the internal state.
@@ -241,7 +243,8 @@ void GcsActorScheduler::LeaseWorkerFromNode(std::shared_ptr<GcsActor> actor,
 
 void GcsActorScheduler::RetryLeasingWorkerFromNode(
     std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node) {
   RAY_UNUSED(execute_after(
-      io_context_, [this, node, actor] { DoRetryLeasingWorkerFromNode(actor, node); },
+      io_context_,
+      [this, node, actor] { DoRetryLeasingWorkerFromNode(actor, node); },
       RayConfig::instance().gcs_lease_worker_retry_interval_ms()));
 }
 
@@ -304,7 +307,8 @@ void GcsActorScheduler::HandleWorkerLeaseGrantedReply(
   // Without this, there could be a possible race condition. Related issues:
   // https://github.com/ray-project/ray/pull/9215/files#r449469320
   core_worker_clients_.GetOrConnect(leased_worker->GetAddress());
-  RAY_CHECK_OK(gcs_actor_table_.Put(actor->GetActorID(), actor->GetActorTableData(),
+  RAY_CHECK_OK(gcs_actor_table_.Put(actor->GetActorID(),
+                                    actor->GetActorTableData(),
                                     [this, actor, leased_worker](Status status) {
                                       RAY_CHECK_OK(status);
                                       CreateActorOnWorker(actor, leased_worker);
@@ -313,7 +317,8 @@
 }
 
 void GcsActorScheduler::HandleRequestWorkerLeaseCanceled(
-    std::shared_ptr<GcsActor> actor, const NodeID &node_id,
+    std::shared_ptr<GcsActor> actor,
+    const NodeID &node_id,
     rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type,
     const std::string &scheduling_failure_message) {
   RAY_LOG(INFO)
@@ -393,7 +398,8 @@ void GcsActorScheduler::RetryCreatingActorOnWorker(
   RAY_LOG(DEBUG) << "Retry creating actor " << actor->GetActorID() << " on worker "
                  << worker->GetWorkerID();
   RAY_UNUSED(execute_after(
-      io_context_, [this, actor, worker] { DoRetryCreatingActorOnWorker(actor, worker); },
+      io_context_,
+      [this, actor, worker] { DoRetryCreatingActorOnWorker(actor, worker); },
       RayConfig::instance().gcs_create_actor_retry_interval_ms()));
 }
 
@@ -485,8 +491,10 @@ std::shared_ptr<rpc::GcsNodeInfo> RayletBasedActorScheduler::SelectNodeRandomly(
 }
 
 void RayletBasedActorScheduler::HandleWorkerLeaseReply(
-    std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node,
-    const Status &status, const rpc::RequestWorkerLeaseReply &reply) {
+    std::shared_ptr<GcsActor> actor,
+    std::shared_ptr<rpc::GcsNodeInfo> node,
+    const Status &status,
+    const rpc::RequestWorkerLeaseReply &reply) {
   // If the actor is still in the leasing map and the status is ok, remove the actor
   // from the leasing map and handle the reply. Otherwise, lease again, because it
   // may be a network exception.
@@ -517,7 +525,9 @@ void RayletBasedActorScheduler::HandleWorkerLeaseReply(
   if (status.ok()) {
     if (reply.canceled()) {
       HandleRequestWorkerLeaseCanceled(
-          actor, node_id, reply.failure_type(),
+          actor,
+          node_id,
+          reply.failure_type(),
           /*scheduling_failure_message*/ reply.scheduling_failure_message());
       return;
     }
diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
index 2ee66b1dd..dc3fc82ba 100644
--- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
+++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h
@@ -35,9 +35,10 @@ namespace gcs {
 
 class GcsActor;
 
-using GcsActorSchedulerFailureCallback = std::function<void(
-    std::shared_ptr<GcsActor>, rpc::RequestWorkerLeaseReply::SchedulingFailureType,
-    const std::string &)>;
+using GcsActorSchedulerFailureCallback =
+    std::function<void(std::shared_ptr<GcsActor>,
+                       rpc::RequestWorkerLeaseReply::SchedulingFailureType,
+                       const std::string &)>;
 using GcsActorSchedulerSuccessCallback =
     std::function<void(std::shared_ptr<GcsActor>, const rpc::PushTaskReply &reply)>;
 
@@ -63,7 +64,8 @@ class GcsActorSchedulerInterface {
   ///
   /// \param node_id ID of the node where the actor leasing request has been sent.
   /// \param actor_id ID of an actor.
-  virtual void CancelOnLeasing(const NodeID &node_id, const ActorID &actor_id,
+  virtual void CancelOnLeasing(const NodeID &node_id,
+                               const ActorID &actor_id,
                                const TaskID &task_id) = 0;
 
   /// Cancel the actor that is being scheduled to the specified worker.
@@ -106,7 +108,8 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
   /// \param client_factory Factory to create remote core worker client, default factor
   /// will be used if not set.
   explicit GcsActorScheduler(
-      instrumented_io_context &io_context, GcsActorTable &gcs_actor_table,
+      instrumented_io_context &io_context,
+      GcsActorTable &gcs_actor_table,
       const GcsNodeManager &gcs_node_manager,
       GcsActorSchedulerFailureCallback schedule_failure_handler,
       GcsActorSchedulerSuccessCallback schedule_success_handler,
@@ -140,7 +143,8 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
   ///
   /// \param node_id ID of the node where the actor leasing request has been sent.
   /// \param actor_id ID of an actor.
-  void CancelOnLeasing(const NodeID &node_id, const ActorID &actor_id,
+  void CancelOnLeasing(const NodeID &node_id,
+                       const ActorID &actor_id,
                        const TaskID &task_id) override;
 
   /// Cancel the actor that is being scheduled to the specified worker.
@@ -272,7 +276,8 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
   /// \param failure_type The type of the canceling.
   /// \param scheduling_failure_message The scheduling failure error message.
   void HandleRequestWorkerLeaseCanceled(
-      std::shared_ptr<GcsActor> actor, const NodeID &node_id,
+      std::shared_ptr<GcsActor> actor,
+      const NodeID &node_id,
       rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type,
       const std::string &scheduling_failure_message);
 
diff --git a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc
index 49ffc1e04..d2dd602c5 100644
--- a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "ray/gcs/gcs_server/gcs_heartbeat_manager.h"
+
 #include "ray/common/ray_config.h"
 #include "ray/gcs/pb_util.h"
 #include "src/ray/protobuf/gcs.pb.h"
@@ -73,14 +74,15 @@ void GcsHeartbeatManager::AddNode(const NodeID &node_id) {
 }
 
 void GcsHeartbeatManager::HandleReportHeartbeat(
-    const rpc::ReportHeartbeatRequest &request, rpc::ReportHeartbeatReply *reply,
+    const rpc::ReportHeartbeatRequest &request,
+    rpc::ReportHeartbeatReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   NodeID node_id = NodeID::FromBinary(request.heartbeat().node_id());
   auto iter = heartbeats_.find(node_id);
   if (iter == heartbeats_.end()) {
     // Reply the raylet with an error so the raylet can crash itself.
-    GCS_RPC_SEND_REPLY(send_reply_callback, reply,
-                       Status::Disconnected("Node has been dead"));
+    GCS_RPC_SEND_REPLY(
+        send_reply_callback, reply, Status::Disconnected("Node has been dead"));
     return;
   }
 
diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc
index 431f0fc5d..21c5bd9ea 100644
--- a/src/ray/gcs/gcs_server/gcs_job_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -41,8 +41,8 @@ void GcsJobManager::HandleAddJob(const rpc::AddJobRequest &request,
   RAY_LOG(INFO) << "Adding job, job id = " << job_id
                 << ", driver pid = " << mutable_job_table_data.driver_pid();
 
-  auto on_done = [this, job_id, mutable_job_table_data, reply,
-                  send_reply_callback](const Status &status) {
+  auto on_done = [this, job_id, mutable_job_table_data, reply, send_reply_callback](
+                     const Status &status) {
     if (!status.ok()) {
       RAY_LOG(ERROR) << "Failed to add job, job id = " << job_id
                      << ", driver pid = " << mutable_job_table_data.driver_pid();
@@ -107,8 +107,9 @@ void GcsJobManager::HandleMarkJobFinished(const rpc::MarkJobFinishedRequest &req
   };
 
   Status status = gcs_table_storage_->JobTable().Get(
-      job_id, [this, job_id, send_reply](
-                  Status status, const boost::optional<rpc::JobTableData> &result) {
+      job_id,
+      [this, job_id, send_reply](Status status,
+                                 const boost::optional<rpc::JobTableData> &result) {
         if (status.ok() && result) {
           MarkJobAsFinished(*result, send_reply);
         } else {
diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h
index e03200374..a09dd6a55 100644
--- a/src/ray/gcs/gcs_server/gcs_job_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_job_manager.h
@@ -38,7 +38,8 @@ class GcsJobManager : public rpc::JobInfoHandler {
 
   void Initialize(const GcsInitData &gcs_init_data);
 
-  void HandleAddJob(const rpc::AddJobRequest &request, rpc::AddJobReply *reply,
+  void HandleAddJob(const rpc::AddJobRequest &request,
+                    rpc::AddJobReply *reply,
                     rpc::SendReplyCallback send_reply_callback) override;
 
   void HandleMarkJobFinished(const rpc::MarkJobFinishedRequest &request,
diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.cc b/src/ray/gcs/gcs_server/gcs_kv_manager.cc
index cb63ae5cc..268ddbe40 100644
--- a/src/ray/gcs/gcs_server/gcs_kv_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_kv_manager.cc
@@ -61,7 +61,8 @@ RedisInternalKV::RedisInternalKV(const RedisClientOptions &redis_options)
   RAY_CHECK_OK(redis_client_->Connect(io_service_));
 }
 
-void RedisInternalKV::Get(const std::string &ns, const std::string &key,
+void RedisInternalKV::Get(const std::string &ns,
+                          const std::string &key,
                           std::function<void(std::optional<std::string>)> callback) {
   auto true_key = MakeKey(ns, key);
   std::vector<std::string> cmd = {"HGET", true_key, "value"};
@@ -77,12 +78,14 @@ void RedisInternalKV::Get(const std::string &ns, const std::string &key,
       }));
 }
 
-void RedisInternalKV::Put(const std::string &ns, const std::string &key,
-                          const std::string &value, bool overwrite,
+void RedisInternalKV::Put(const std::string &ns,
+                          const std::string &key,
+                          const std::string &value,
+                          bool overwrite,
                           std::function<void(bool)> callback) {
   auto true_key = MakeKey(ns, key);
-  std::vector<std::string> cmd = {overwrite ? "HSET" : "HSETNX", true_key, "value",
-                                  value};
+  std::vector<std::string> cmd = {
+      overwrite ? "HSET" : "HSETNX", true_key, "value", value};
   RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync(
       cmd, [callback = std::move(callback)](auto redis_reply) {
         if (callback) {
@@ -92,8 +95,10 @@ void RedisInternalKV::Put(const std::string &ns, const std::string &key,
       }));
 }
 
-void RedisInternalKV::Del(const std::string &ns, const std::string &key,
-                          bool del_by_prefix, std::function<void(int64_t)> callback) {
+void RedisInternalKV::Del(const std::string &ns,
+                          const std::string &key,
+                          bool del_by_prefix,
+                          std::function<void(int64_t)> callback) {
   auto true_key = MakeKey(ns, key);
   if (del_by_prefix) {
     std::vector<std::string> cmd = {"KEYS", true_key + "*"};
@@ -131,7 +136,8 @@ void RedisInternalKV::Del(const std::string &ns, const std::string &key,
   }
 }
 
-void RedisInternalKV::Exists(const std::string &ns, const std::string &key,
+void RedisInternalKV::Exists(const std::string &ns,
+                             const std::string &key,
                              std::function<void(bool)> callback) {
   auto true_key = MakeKey(ns, key);
   std::vector<std::string> cmd = {"HEXISTS", true_key, "value"};
@@ -144,7 +150,8 @@ void RedisInternalKV::Exists(const std::string &ns, const std::string &key,
       }));
 }
 
-void RedisInternalKV::Keys(const std::string &ns, const std::string &prefix,
+void RedisInternalKV::Keys(const std::string &ns,
+                           const std::string &prefix,
                            std::function<void(std::vector<std::string>)> callback) {
   auto true_prefix = MakeKey(ns, prefix);
   std::vector<std::string> cmd = {"KEYS", true_prefix + "*"};
@@ -162,7 +169,8 @@ void RedisInternalKV::Keys(const std::string &ns, const std::string &prefix,
       }));
 }
 
-void MemoryInternalKV::Get(const std::string &ns, const std::string &key,
+void MemoryInternalKV::Get(const std::string &ns,
+                           const std::string &key,
                            std::function<void(std::optional<std::string>)> callback) {
   absl::ReaderMutexLock lock(&mu_);
   auto true_prefix = MakeKey(ns, key);
@@ -174,8 +182,10 @@ void MemoryInternalKV::Get(const std::string &ns, const std::string &key,
   }
 }
 
-void MemoryInternalKV::Put(const std::string &ns, const std::string &key,
-                           const std::string &value, bool overwrite,
+void MemoryInternalKV::Put(const std::string &ns,
+                           const std::string &key,
+                           const std::string &value,
+                           bool overwrite,
                            std::function<void(bool)> callback) {
   absl::WriterMutexLock _(&mu_);
   auto true_key = MakeKey(ns, key);
@@ -194,8 +204,10 @@ void MemoryInternalKV::Put(const std::string &ns, const std::string &key,
   }
 }
 
-void MemoryInternalKV::Del(const std::string &ns, const std::string &key,
-                           bool del_by_prefix, std::function<void(int64_t)> callback) {
+void MemoryInternalKV::Del(const std::string &ns,
+                           const std::string &key,
+                           bool del_by_prefix,
+                           std::function<void(int64_t)> callback) {
   absl::WriterMutexLock _(&mu_);
   auto true_key = MakeKey(ns, key);
   auto it = map_.lower_bound(true_key);
@@ -222,7 +234,8 @@ void MemoryInternalKV::Del(const std::string &ns, const std::string &key,
   }
 }
 
-void MemoryInternalKV::Exists(const std::string &ns, const std::string &key,
+void MemoryInternalKV::Exists(const std::string &ns,
+                              const std::string &key,
                               std::function<void(bool)> callback) {
   absl::ReaderMutexLock lock(&mu_);
   auto true_key = MakeKey(ns, key);
@@ -232,7 +245,8 @@ void MemoryInternalKV::Exists(const std::string &ns, const std::string &key,
   }
 }
 
-void MemoryInternalKV::Keys(const std::string &ns, const std::string &prefix,
+void MemoryInternalKV::Keys(const std::string &ns,
+                            const std::string &prefix,
                             std::function<void(std::vector<std::string>)> callback) {
   absl::ReaderMutexLock lock(&mu_);
   std::vector<std::string> keys;
@@ -249,7 +263,8 @@ void MemoryInternalKV::Keys(const std::string &ns, const std::string &prefix,
 }
 
 void GcsInternalKVManager::HandleInternalKVGet(
-    const rpc::InternalKVGetRequest &request, rpc::InternalKVGetReply *reply,
+    const rpc::InternalKVGetRequest &request,
+    rpc::InternalKVGetReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   auto status = ValidateKey(request.key());
   if (!status.ok()) {
@@ -260,8 +275,8 @@ void GcsInternalKVManager::HandleInternalKVGet(
         reply->set_value(*val);
         GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
       } else {
-        GCS_RPC_SEND_REPLY(send_reply_callback, reply,
-                           Status::NotFound("Failed to find the key"));
+        GCS_RPC_SEND_REPLY(
+            send_reply_callback, reply, Status::NotFound("Failed to find the key"));
      }
    };
    kv_instance_->Get(request.namespace_(), request.key(), std::move(callback));
@@ -269,7 +284,8 @@
 }
 
 void GcsInternalKVManager::HandleInternalKVPut(
-    const rpc::InternalKVPutRequest &request, rpc::InternalKVPutReply *reply,
+    const rpc::InternalKVPutRequest &request,
+    rpc::InternalKVPutReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   auto status = ValidateKey(request.key());
   if (!status.ok()) {
@@ -279,13 +295,17 @@
       reply->set_added_num(newly_added ? 1 : 0);
       GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
     };
-    kv_instance_->Put(request.namespace_(), request.key(), request.value(),
-                      request.overwrite(), std::move(callback));
+    kv_instance_->Put(request.namespace_(),
+                      request.key(),
+                      request.value(),
+                      request.overwrite(),
+                      std::move(callback));
   }
 }
 
 void GcsInternalKVManager::HandleInternalKVDel(
-    const rpc::InternalKVDelRequest &request, rpc::InternalKVDelReply *reply,
+    const rpc::InternalKVDelRequest &request,
+    rpc::InternalKVDelReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   auto status = ValidateKey(request.key());
   if (!status.ok()) {
@@ -295,13 +315,16 @@
       reply->set_deleted_num(del_num);
       GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
     };
-    kv_instance_->Del(request.namespace_(), request.key(), request.del_by_prefix(),
+    kv_instance_->Del(request.namespace_(),
+                      request.key(),
+                      request.del_by_prefix(),
                       std::move(callback));
   }
 }
 
 void GcsInternalKVManager::HandleInternalKVExists(
-    const rpc::InternalKVExistsRequest &request, rpc::InternalKVExistsReply *reply,
+    const rpc::InternalKVExistsRequest &request,
+    rpc::InternalKVExistsReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   auto status = ValidateKey(request.key());
   if (!status.ok()) {
@@ -317,7 +340,8 @@
 }
 
 void GcsInternalKVManager::HandleInternalKVKeys(
-    const rpc::InternalKVKeysRequest &request, rpc::InternalKVKeysReply *reply,
+    const rpc::InternalKVKeysRequest &request,
+    rpc::InternalKVKeysReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   auto status = ValidateKey(request.prefix());
   if (!status.ok()) {
diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.h b/src/ray/gcs/gcs_server/gcs_kv_manager.h
index c5cf3a915..f9b592a60 100644
--- a/src/ray/gcs/gcs_server/gcs_kv_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_kv_manager.h
@@ -35,7 +35,8 @@ class InternalKVInterface {
   /// \param ns The namespace of the key.
   /// \param key The key to fetch.
   /// \param callback Callback function.
-  virtual void Get(const std::string &ns, const std::string &key,
+  virtual void Get(const std::string &ns,
+                   const std::string &key,
                    std::function<void(std::optional<std::string>)> callback) = 0;
 
   /// Associate a key with the specified value.
@@ -46,8 +47,10 @@ class InternalKVInterface {
   /// \param overwrite Whether to overwrite existing values. Otherwise, the update
   /// will be ignored.
/// \param callback Callback function. - virtual void Put(const std::string &ns, const std::string &key, - const std::string &value, bool overwrite, + virtual void Put(const std::string &ns, + const std::string &key, + const std::string &value, + bool overwrite, std::function callback) = 0; /// Delete the key from the store. @@ -57,7 +60,9 @@ class InternalKVInterface { /// \param del_by_prefix Whether to treat the key as a prefix. If true, it'll /// delete all keys with `key` as the prefix. /// \param callback Callback function. - virtual void Del(const std::string &ns, const std::string &key, bool del_by_prefix, + virtual void Del(const std::string &ns, + const std::string &key, + bool del_by_prefix, std::function callback) = 0; /// Check whether the key exists in the store. @@ -65,7 +70,8 @@ class InternalKVInterface { /// \param ns The namespace of the key. /// \param key The key to be checked. /// \param callback Callback function. - virtual void Exists(const std::string &ns, const std::string &key, + virtual void Exists(const std::string &ns, + const std::string &key, std::function callback) = 0; /// Get the keys for a given prefix. @@ -73,7 +79,8 @@ class InternalKVInterface { /// \param ns The namespace of the prefix. /// \param prefix The prefix to be scanned. /// \param callback Callback function. - virtual void Keys(const std::string &ns, const std::string &prefix, + virtual void Keys(const std::string &ns, + const std::string &prefix, std::function)> callback) = 0; /// Return the event loop associated with the instance. This is where the @@ -94,19 +101,27 @@ class RedisInternalKV : public InternalKVInterface { io_thread_.reset(); } - void Get(const std::string &ns, const std::string &key, + void Get(const std::string &ns, + const std::string &key, std::function)> callback) override; - void Put(const std::string &ns, const std::string &key, const std::string &value, - bool overwrite, std::function callback) override; + void Put(const std::string &ns, + const std::string &key, + const std::string &value, + bool overwrite, + std::function callback) override; - void Del(const std::string &ns, const std::string &key, bool del_by_prefix, + void Del(const std::string &ns, + const std::string &key, + bool del_by_prefix, std::function callback) override; - void Exists(const std::string &ns, const std::string &key, + void Exists(const std::string &ns, + const std::string &key, std::function callback) override; - void Keys(const std::string &ns, const std::string &prefix, + void Keys(const std::string &ns, + const std::string &prefix, std::function)> callback) override; instrumented_io_context &GetEventLoop() override { return io_service_; } @@ -123,19 +138,27 @@ class RedisInternalKV : public InternalKVInterface { class MemoryInternalKV : public InternalKVInterface { public: MemoryInternalKV(instrumented_io_context &io_context) : io_context_(io_context) {} - void Get(const std::string &ns, const std::string &key, + void Get(const std::string &ns, + const std::string &key, std::function)> callback) override; - void Put(const std::string &ns, const std::string &key, const std::string &value, - bool overwrite, std::function callback) override; + void Put(const std::string &ns, + const std::string &key, + const std::string &value, + bool overwrite, + std::function callback) override; - void Del(const std::string &ns, const std::string &key, bool del_by_prefix, + void Del(const std::string &ns, + const std::string &key, + bool del_by_prefix, std::function callback) override; - void Exists(const
std::string &ns, const std::string &key, + void Exists(const std::string &ns, + const std::string &key, std::function callback) override; - void Keys(const std::string &ns, const std::string &prefix, + void Keys(const std::string &ns, + const std::string &prefix, std::function)> callback) override; instrumented_io_context &GetEventLoop() override { return io_context_; } diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 3f0184285..e86b34461 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -41,8 +41,8 @@ void GcsNodeManager::HandleRegisterNode(const rpc::RegisterNodeRequest &request, NodeID node_id = NodeID::FromBinary(request.node_info().node_id()); RAY_LOG(INFO) << "Registering node info, node id = " << node_id << ", address = " << request.node_info().node_manager_address(); - auto on_done = [this, node_id, request, reply, - send_reply_callback](const Status &status) { + auto on_done = [this, node_id, request, reply, send_reply_callback]( + const Status &status) { RAY_CHECK_OK(status); RAY_LOG(INFO) << "Finished registering node info, node id = " << node_id << ", address = " << request.node_info().node_manager_address(); @@ -92,34 +92,39 @@ void GcsNodeManager::DrainNode(const NodeID &node_id) { remote_address.set_raylet_id(node->node_id()); remote_address.set_ip_address(node->node_manager_address()); remote_address.set_port(node->node_manager_port()); - auto on_put_done = [this, remote_address = remote_address, node_id, + auto on_put_done = [this, + remote_address = remote_address, + node_id, node_info_delta = node_info_delta](const Status &status) { - auto on_resource_update_done = - [this, remote_address = std::move(remote_address), node_id, - node_info_delta = node_info_delta](const Status &status) { - auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(remote_address); - RAY_CHECK(raylet_client); - // NOTE(sang): Drain API is not supposed to kill the raylet, but we are doing - // this until the proper "drain" behavior is implemented. Currently, before - // raylet is killed, it sends a drain request to GCS. That said, this can - // happen; - // - GCS updates the drain state and kills a raylet gracefully. - // - Raylet kills itself and send a drain request of itself to GCS. - // - Drain request will become a no-op in GCS. - // This behavior is redundant, but harmless. We'll keep this behavior until we - // implement the right drain behavior for the simplicity. Check - // https://github.com/ray-project/ray/pull/19350 for more details. - raylet_client->ShutdownRaylet( - node_id, /*graceful*/ true, - [this, node_id, node_info_delta = node_info_delta]( - const Status &status, const rpc::ShutdownRayletReply &reply) { - RAY_LOG(INFO) << "Raylet " << node_id << " is drained. Status " << status - << ". The information will be published to the cluster."; - /// Once the raylet is shutdown, inform all nodes that the raylet is dead. - RAY_CHECK_OK( - gcs_publisher_->PublishNodeInfo(node_id, *node_info_delta, nullptr)); - }); - }; + auto on_resource_update_done = [this, + remote_address = std::move(remote_address), + node_id, + node_info_delta = + node_info_delta](const Status &status) { + auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(remote_address); + RAY_CHECK(raylet_client); + // NOTE(sang): Drain API is not supposed to kill the raylet, but we are doing + // this until the proper "drain" behavior is implemented. 
Currently, before + // raylet is killed, it sends a drain request to GCS. That said, this can + // happen: + // - GCS updates the drain state and kills a raylet gracefully. + // - Raylet kills itself and sends a drain request of itself to GCS. + // - Drain request will become a no-op in GCS. + // This behavior is redundant, but harmless. We'll keep this behavior until we + // implement the right drain behavior, for simplicity. Check + // https://github.com/ray-project/ray/pull/19350 for more details. + raylet_client->ShutdownRaylet( + node_id, + /*graceful*/ true, + [this, node_id, node_info_delta = node_info_delta]( + const Status &status, const rpc::ShutdownRayletReply &reply) { + RAY_LOG(INFO) << "Raylet " << node_id << " is drained. Status " << status + << ". The information will be published to the cluster."; + /// Once the raylet is shut down, inform all nodes that the raylet is dead. + RAY_CHECK_OK( + gcs_publisher_->PublishNodeInfo(node_id, *node_info_delta, nullptr)); + }); + }; RAY_CHECK_OK( gcs_table_storage_->NodeResourceTable().Delete(node_id, on_resource_update_done)); }; diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.h b/src/ray/gcs/gcs_server/gcs_node_manager.h index 00040e659..0f6419173 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/ray/gcs/gcs_server/gcs_node_manager.h @@ -49,7 +49,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler { rpc::SendReplyCallback send_reply_callback) override; /// Handle unregister rpc requests coming from raylets. - void HandleDrainNode(const rpc::DrainNodeRequest &request, rpc::DrainNodeReply *reply, + void HandleDrainNode(const rpc::DrainNodeRequest &request, + rpc::DrainNodeReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Handle get all node info rpc request. diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index bda679974..fa7cc6d36 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -193,7 +193,8 @@ void GcsPlacementGroupManager::RegisterPlacementGroup( AddToPendingQueue(placement_group); RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( - placement_group_id, placement_group->GetPlacementGroupTableData(), + placement_group_id, + placement_group->GetPlacementGroupTableData(), [this, placement_group_id, placement_group](Status status) { // The backend storage is supposed to be reliable, so the status must be ok.
RAY_CHECK_OK(status); @@ -235,7 +236,8 @@ PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName( } void GcsPlacementGroupManager::OnPlacementGroupCreationFailed( - std::shared_ptr placement_group, ExponentialBackOff backoff, + std::shared_ptr placement_group, + ExponentialBackOff backoff, bool is_feasible) { RAY_LOG(DEBUG) << "Failed to create placement group " << placement_group->GetName() << ", id: " << placement_group->GetPlacementGroupID() << ", try again."; @@ -299,7 +301,8 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess( placement_group->UpdateState(rpc::PlacementGroupTableData::CREATED); auto placement_group_id = placement_group->GetPlacementGroupID(); RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( - placement_group_id, placement_group->GetPlacementGroupTableData(), + placement_group_id, + placement_group->GetPlacementGroupTableData(), [this, placement_group_id](Status status) { RAY_CHECK_OK(status); // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this @@ -356,8 +359,8 @@ void GcsPlacementGroupManager::SchedulePendingPlacementGroups() { placement_group, [this, backoff](std::shared_ptr placement_group, bool is_feasible) { - OnPlacementGroupCreationFailed(std::move(placement_group), backoff, - is_feasible); + OnPlacementGroupCreationFailed( + std::move(placement_group), backoff, is_feasible); }, [this](std::shared_ptr placement_group) { OnPlacementGroupCreationSuccess(std::move(placement_group)); @@ -378,34 +381,37 @@ void GcsPlacementGroupManager::HandleCreatePlacementGroup( auto placement_group = std::make_shared(request, get_ray_namespace_(job_id)); RAY_LOG(DEBUG) << "Registering placement group, " << placement_group->DebugString(); - RegisterPlacementGroup(placement_group, [reply, send_reply_callback, - placement_group](Status status) { - if (status.ok()) { - RAY_LOG(DEBUG) << "Finished registering placement group, " - << placement_group->DebugString(); - } else { - RAY_LOG(INFO) << "Failed to register placement group, " - << placement_group->DebugString() << ", cause: " << status.message(); - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }); + RegisterPlacementGroup(placement_group, + [reply, send_reply_callback, placement_group](Status status) { + if (status.ok()) { + RAY_LOG(DEBUG) << "Finished registering placement group, " + << placement_group->DebugString(); + } else { + RAY_LOG(INFO) << "Failed to register placement group, " + << placement_group->DebugString() + << ", cause: " << status.message(); + } + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }); ++counts_[CountType::CREATE_PLACEMENT_GROUP_REQUEST]; } void GcsPlacementGroupManager::HandleRemovePlacementGroup( const rpc::RemovePlacementGroupRequest &request, - rpc::RemovePlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::RemovePlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) { const auto placement_group_id = PlacementGroupID::FromBinary(request.placement_group_id()); - RemovePlacementGroup(placement_group_id, [send_reply_callback, reply, - placement_group_id](Status status) { - if (status.ok()) { - RAY_LOG(INFO) << "Placement group of an id, " << placement_group_id - << " is removed successfully."; - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }); + RemovePlacementGroup(placement_group_id, + [send_reply_callback, reply, placement_group_id](Status status) { + if (status.ok()) { + RAY_LOG(INFO) + << "Placement group of an id, " << placement_group_id 
+ << " is removed successfully."; + } + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }); ++counts_[CountType::REMOVE_PLACEMENT_GROUP_REQUEST]; } @@ -452,7 +458,8 @@ void GcsPlacementGroupManager::RemovePlacementGroup( // Remove a placement group from infeasible queue if exists. auto pending_it = std::find_if( - infeasible_placement_groups_.begin(), infeasible_placement_groups_.end(), + infeasible_placement_groups_.begin(), + infeasible_placement_groups_.end(), [placement_group_id](const std::shared_ptr &placement_group) { return placement_group->GetPlacementGroupID() == placement_group_id; }); @@ -485,7 +492,8 @@ void GcsPlacementGroupManager::RemovePlacementGroup( } void GcsPlacementGroupManager::HandleGetPlacementGroup( - const rpc::GetPlacementGroupRequest &request, rpc::GetPlacementGroupReply *reply, + const rpc::GetPlacementGroupRequest &request, + rpc::GetPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { PlacementGroupID placement_group_id = PlacementGroupID::FromBinary(request.placement_group_id()); @@ -518,7 +526,8 @@ void GcsPlacementGroupManager::HandleGetPlacementGroup( void GcsPlacementGroupManager::HandleGetNamedPlacementGroup( const rpc::GetNamedPlacementGroupRequest &request, - rpc::GetNamedPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::GetNamedPlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) { const std::string &name = request.name(); RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name; @@ -542,7 +551,8 @@ void GcsPlacementGroupManager::HandleGetNamedPlacementGroup( void GcsPlacementGroupManager::HandleGetAllPlacementGroup( const rpc::GetAllPlacementGroupRequest &request, - rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::GetAllPlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Getting all placement group info."; auto on_done = [this, reply, send_reply_callback]( @@ -579,19 +589,20 @@ void GcsPlacementGroupManager::HandleWaitPlacementGroupUntilReady( RAY_LOG(DEBUG) << "Waiting for placement group until ready, placement group id = " << placement_group_id; - WaitPlacementGroup(placement_group_id, [reply, send_reply_callback, - placement_group_id](Status status) { - if (status.ok()) { - RAY_LOG(DEBUG) - << "Finished waiting for placement group until ready, placement group id = " - << placement_group_id; - } else { - RAY_LOG(WARNING) - << "Failed to waiting for placement group until ready, placement group id = " - << placement_group_id << ", cause: " << status.message(); - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }); + WaitPlacementGroup( + placement_group_id, + [reply, send_reply_callback, placement_group_id](Status status) { + if (status.ok()) { + RAY_LOG(DEBUG) + << "Finished waiting for placement group until ready, placement group id = " + << placement_group_id; + } else { + RAY_LOG(WARNING) << "Failed to waiting for placement group until ready, " + "placement group id = " + << placement_group_id << ", cause: " << status.message(); + } + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }); ++counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST]; } @@ -638,7 +649,8 @@ void GcsPlacementGroupManager::WaitPlacementGroup( } void GcsPlacementGroupManager::AddToPendingQueue( - std::shared_ptr pg, std::optional rank, + std::shared_ptr pg, + std::optional rank, std::optional exp_backer) { if (!rank) { rank = absl::GetCurrentTimeNanos(); @@ -669,7 
+681,8 @@ void GcsPlacementGroupManager::AddToPendingQueue( void GcsPlacementGroupManager::RemoveFromPendingQueue(const PlacementGroupID &pg_id) { auto it = std::find_if(pending_placement_groups_.begin(), - pending_placement_groups_.end(), [&pg_id](const auto &val) { + pending_placement_groups_.end(), + [&pg_id](const auto &val) { return val.second.second->GetPlacementGroupID() == pg_id; }); // The placement group was pending scheduling, remove it from the queue. diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index 97c63da87..877cdde05 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -244,7 +244,8 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { /// \param placement_group The placement_group whose creation task is infeasible. /// \param is_feasible Whether the scheduler can currently retry the creation task. void OnPlacementGroupCreationFailed(std::shared_ptr placement_group, - ExponentialBackOff backoff, bool is_feasible); + ExponentialBackOff backoff, + bool is_feasible); /// Handle placement_group creation task success. This should be called when the /// placement_group creation task has been scheduled successfully. diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index 4dbcff427..74767788c 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -48,7 +48,8 @@ namespace gcs { GcsPlacementGroupScheduler::GcsPlacementGroupScheduler( instrumented_io_context &io_context, std::shared_ptr gcs_table_storage, - const gcs::GcsNodeManager &gcs_node_manager, GcsResourceManager &gcs_resource_manager, + const gcs::GcsNodeManager &gcs_node_manager, + GcsResourceManager &gcs_resource_manager, GcsResourceScheduler &gcs_resource_scheduler, std::shared_ptr raylet_client_pool, syncer::RaySyncer &ray_syncer) @@ -76,7 +77,8 @@ std::vector GcsScheduleStrategy::GetRequiredResourcesFromBundle ScheduleResult GcsScheduleStrategy::GenerateScheduleResult( const std::vector> &bundles, - const std::vector &selected_nodes, const SchedulingResultStatus &status) { + const std::vector &selected_nodes, + const SchedulingResultStatus &status) { ScheduleMap schedule_map; if (status == SchedulingResultStatus::SUCCESS && !selected_nodes.empty()) { RAY_CHECK(bundles.size() == selected_nodes.size()); @@ -95,8 +97,8 @@ ScheduleResult GcsStrictPackStrategy::Schedule( const auto &required_resources = GetRequiredResourcesFromBundles(bundles); const auto &scheduling_result = gcs_resource_scheduler.Schedule(required_resources, SchedulingType::STRICT_PACK); - return GenerateScheduleResult(bundles, scheduling_result.second, - scheduling_result.first); + return GenerateScheduleResult( + bundles, scheduling_result.second, scheduling_result.first); } ScheduleResult GcsPackStrategy::Schedule( @@ -109,8 +111,8 @@ ScheduleResult GcsPackStrategy::Schedule( const auto &required_resources = GetRequiredResourcesFromBundles(bundles); const auto &scheduling_result = gcs_resource_scheduler.Schedule(required_resources, SchedulingType::PACK); - return GenerateScheduleResult(bundles, scheduling_result.second, - scheduling_result.first); + return GenerateScheduleResult( + bundles, scheduling_result.second, scheduling_result.first); } ScheduleResult GcsSpreadStrategy::Schedule( @@ -120,8 +122,8 @@ ScheduleResult
GcsSpreadStrategy::Schedule( const auto &required_resources = GetRequiredResourcesFromBundles(bundles); const auto &scheduling_result = gcs_resource_scheduler.Schedule(required_resources, SchedulingType::SPREAD); - return GenerateScheduleResult(bundles, scheduling_result.second, - scheduling_result.first); + return GenerateScheduleResult( + bundles, scheduling_result.second, scheduling_result.first); } ScheduleResult GcsStrictSpreadStrategy::Schedule( @@ -143,12 +145,13 @@ ScheduleResult GcsStrictSpreadStrategy::Schedule( const auto &required_resources = GetRequiredResourcesFromBundles(bundles); const auto &scheduling_result = gcs_resource_scheduler.Schedule( - required_resources, SchedulingType::STRICT_SPREAD, + required_resources, + SchedulingType::STRICT_SPREAD, /*node_filter_func=*/[&nodes_in_use](const NodeID &node_id) { return nodes_in_use.count(node_id) == 0; }); - return GenerateScheduleResult(bundles, scheduling_result.second, - scheduling_result.first); + return GenerateScheduleResult( + bundles, scheduling_result.second, scheduling_result.first); } void GcsPlacementGroupScheduler::ScheduleUnplacedBundles( @@ -173,7 +176,8 @@ void GcsPlacementGroupScheduler::ScheduleUnplacedBundles( << ", id: " << placement_group->GetPlacementGroupID() << ", bundles size = " << bundles.size(); auto scheduling_result = scheduler_strategies_[strategy]->Schedule( - bundles, GetScheduleContext(placement_group->GetPlacementGroupID()), + bundles, + GetScheduleContext(placement_group->GetPlacementGroupID()), gcs_resource_scheduler_); auto result_status = scheduling_result.first; @@ -206,12 +210,17 @@ void GcsPlacementGroupScheduler::ScheduleUnplacedBundles( // TODO(sang): The callback might not be called at all if nodes are dead. We should // handle this case properly. - PrepareResources(bundles_per_node, gcs_node_manager_.GetAliveNode(node_id), - [this, bundles_per_node, node_id, lease_status_tracker, - failure_callback, success_callback](const Status &status) { + PrepareResources(bundles_per_node, + gcs_node_manager_.GetAliveNode(node_id), + [this, + bundles_per_node, + node_id, + lease_status_tracker, + failure_callback, + success_callback](const Status &status) { for (const auto &bundle : bundles_per_node) { - lease_status_tracker->MarkPrepareRequestReturned(node_id, bundle, - status); + lease_status_tracker->MarkPrepareRequestReturned( + node_id, bundle, status); } if (lease_status_tracker->AllPrepareRequestsReturned()) { @@ -260,8 +269,9 @@ void GcsPlacementGroupScheduler::PrepareResources( << " for bundles: " << GetDebugStringForBundles(bundles); lease_client->PrepareBundleResources( - bundles, [node_id, bundles, callback]( - const Status &status, const rpc::PrepareBundleResourcesReply &reply) { + bundles, + [node_id, bundles, callback](const Status &status, + const rpc::PrepareBundleResourcesReply &reply) { auto result = reply.success() ? 
Status::OK() : Status::IOError("Failed to reserve resource"); if (result.ok()) { @@ -286,8 +296,9 @@ void GcsPlacementGroupScheduler::CommitResources( RAY_LOG(DEBUG) << "Committing resource to a node " << node_id << " for bundles: " << GetDebugStringForBundles(bundles); lease_client->CommitBundleResources( - bundles, [bundles, node_id, callback]( - const Status &status, const rpc::CommitBundleResourcesReply &reply) { + bundles, + [bundles, node_id, callback](const Status &status, + const rpc::CommitBundleResourcesReply &reply) { if (status.ok()) { RAY_LOG(DEBUG) << "Finished committing resource to " << node_id << " for bundles: " << GetDebugStringForBundles(bundles); @@ -370,8 +381,11 @@ void GcsPlacementGroupScheduler::CommitAllBundles( const auto &node = gcs_node_manager_.GetAliveNode(node_id); const auto &bundles_per_node = node_to_bundles.second; - auto commit_resources_callback = [this, lease_status_tracker, bundles_per_node, - node_id, schedule_failure_handler, + auto commit_resources_callback = [this, + lease_status_tracker, + bundles_per_node, + node_id, + schedule_failure_handler, schedule_success_handler](const Status &status) { for (const auto &bundle : bundles_per_node) { lease_status_tracker->MarkCommitRequestReturned(node_id, bundle, status); @@ -387,8 +401,8 @@ void GcsPlacementGroupScheduler::CommitAllBundles( ray_syncer_.Update(std::move(node_resource_change)); } if (lease_status_tracker->AllCommitRequestReturned()) { - OnAllBundleCommitRequestReturned(lease_status_tracker, schedule_failure_handler, - schedule_success_handler); + OnAllBundleCommitRequestReturned( + lease_status_tracker, schedule_failure_handler, schedule_success_handler); } }; @@ -445,11 +459,12 @@ void GcsPlacementGroupScheduler::OnAllBundlePrepareRequestReturned( {key, (*prepared_bundle_locations)[iter->BundleId()].first.Binary()}); } RAY_CHECK_OK(gcs_table_storage_->PlacementGroupScheduleTable().Put( - placement_group_id, data, - [this, schedule_success_handler, schedule_failure_handler, - lease_status_tracker](Status status) { - CommitAllBundles(lease_status_tracker, schedule_failure_handler, - schedule_success_handler); + placement_group_id, + data, + [this, schedule_success_handler, schedule_failure_handler, lease_status_tracker]( + Status status) { + CommitAllBundles( + lease_status_tracker, schedule_failure_handler, schedule_success_handler); })); } @@ -776,7 +791,8 @@ bool LeaseStatusTracker::MarkPreparePhaseStarted( } void LeaseStatusTracker::MarkPrepareRequestReturned( - const NodeID &node_id, const std::shared_ptr &bundle, + const NodeID &node_id, + const std::shared_ptr &bundle, const Status &status) { RAY_CHECK(prepare_request_returned_count_ <= bundles_to_schedule_.size()); auto leasing_bundles = node_to_bundles_when_preparing_.find(node_id); @@ -810,7 +826,8 @@ bool LeaseStatusTracker::AllPrepareRequestsSuccessful() const { } void LeaseStatusTracker::MarkCommitRequestReturned( - const NodeID &node_id, const std::shared_ptr &bundle, + const NodeID &node_id, + const std::shared_ptr &bundle, const Status &status) { commit_request_returned_count_ += 1; // If the request succeeds, record it. 
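Every hunk in this patch applies the same mechanical rule, so it is worth seeing the rule once in isolation. The .clang-format hunk itself is not part of this excerpt; judging from the rewraps, the change presumably disables clang-format's bin-packing, i.e. something like:

    BinPackParameters: false
    BinPackArguments: false

With bin-packing enabled, clang-format fills each wrapped line with as many parameters as fit under the column limit; with it disabled, any parameter list that cannot fit on a single line is broken with exactly one parameter per line. A compilable sketch with illustrative stub types (not the real Ray declarations):

    #include <memory>

    struct NodeID {};
    struct Bundle {};
    struct Status {};

    // Before (bin-packed): parameters share lines up to the column limit.
    void MarkReturnedOld(const NodeID &node_id, const std::shared_ptr<Bundle> &bundle,
                         const Status &status);

    // After (one per line): the list does not fit on one line, so every
    // parameter gets a line of its own.
    void MarkReturnedNew(const NodeID &node_id,
                         const std::shared_ptr<Bundle> &bundle,
                         const Status &status);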
diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index c9e8286e5..13b291b67 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -49,8 +49,10 @@ struct pair_hash { }; using ScheduleMap = absl::flat_hash_map; using ScheduleResult = std::pair; -using BundleLocations = absl::flat_hash_map< - BundleID, std::pair>, pair_hash>; +using BundleLocations = + absl::flat_hash_map>, + pair_hash>; class GcsPlacementGroupSchedulerInterface { public: @@ -140,7 +142,8 @@ class GcsScheduleStrategy { /// \return The scheduling result from the required resource. ScheduleResult GenerateScheduleResult( const std::vector> &bundles, - const std::vector &selected_nodes, const SchedulingResultStatus &status); + const std::vector &selected_nodes, + const SchedulingResultStatus &status); }; /// The `GcsPackStrategy` packs all bundles into one node as much as possible. @@ -219,7 +222,8 @@ class LeaseStatusTracker { /// \param status Status of the prepare response. void MarkPrepareRequestReturned( - const NodeID &node_id, const std::shared_ptr &bundle, + const NodeID &node_id, + const std::shared_ptr &bundle, const Status &status); /// Used to know if all prepare requests are returned. @@ -417,7 +421,8 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { GcsPlacementGroupScheduler( instrumented_io_context &io_context, std::shared_ptr gcs_table_storage, - const GcsNodeManager &gcs_node_manager, GcsResourceManager &gcs_resource_manager, + const GcsNodeManager &gcs_node_manager, + GcsResourceManager &gcs_resource_manager, GcsResourceScheduler &gcs_resource_scheduler, std::shared_ptr raylet_client_pool, syncer::RaySyncer &ray_syncer); @@ -470,9 +475,10 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// This should be called when GCS server restarts after a failure. /// /// \param group_to_bundles Bundles used by each placement group. - void Initialize(const absl::flat_hash_map< - PlacementGroupID, std::vector>> - &group_to_bundles) override; + void Initialize( + const absl::flat_hash_map>> + &group_to_bundles) override; protected: /// Send bundles PREPARE requests to a node.
The PREPARE requests will lock resources diff --git a/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc b/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc index c01d1e138..a576ad268 100644 --- a/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc +++ b/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc @@ -20,7 +20,8 @@ namespace ray { namespace gcs { GcsRedisFailureDetector::GcsRedisFailureDetector( - instrumented_io_context &io_service, std::shared_ptr redis_context, + instrumented_io_context &io_service, + std::shared_ptr redis_context, std::function callback) : redis_context_(redis_context), periodical_runner_(io_service), diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.cc b/src/ray/gcs/gcs_server/gcs_resource_manager.cc index b539db029..ad8a83432 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.cc @@ -21,7 +21,8 @@ namespace ray { namespace gcs { GcsResourceManager::GcsResourceManager( - instrumented_io_context &main_io_service, std::shared_ptr gcs_publisher, + instrumented_io_context &main_io_service, + std::shared_ptr gcs_publisher, std::shared_ptr gcs_table_storage) : periodical_runner_(main_io_service), gcs_publisher_(gcs_publisher), @@ -204,7 +205,8 @@ void GcsResourceManager::UpdateFromResourceReport(const rpc::ResourcesData &data } void GcsResourceManager::HandleReportResourceUsage( - const rpc::ReportResourceUsageRequest &request, rpc::ReportResourceUsageReply *reply, + const rpc::ReportResourceUsageRequest &request, + rpc::ReportResourceUsageReply *reply, rpc::SendReplyCallback send_reply_callback) { UpdateFromResourceReport(request.resources()); @@ -213,7 +215,8 @@ void GcsResourceManager::HandleReportResourceUsage( } void GcsResourceManager::HandleGetAllResourceUsage( - const rpc::GetAllResourceUsageRequest &request, rpc::GetAllResourceUsageReply *reply, + const rpc::GetAllResourceUsageRequest &request, + rpc::GetAllResourceUsageReply *reply, rpc::SendReplyCallback send_reply_callback) { if (!node_resource_usages_.empty()) { auto batch = std::make_shared(); @@ -299,8 +302,8 @@ void GcsResourceManager::Initialize(const GcsInitData &gcs_init_data) { if (iter != cluster_scheduling_resources_.end()) { auto node_resources = iter->second->GetMutableLocalView(); for (const auto &resource : entry.second.items()) { - UpdateResourceCapacity(node_resources, resource.first, - resource.second.resource_capacity()); + UpdateResourceCapacity( + node_resources, resource.first, resource.second.resource_capacity()); } } } @@ -361,8 +364,9 @@ void GcsResourceManager::OnNodeAdd(const rpc::GcsNodeInfo &node) { node.resources_total().begin(), node.resources_total().end()); // Update the cluster scheduling resources as new node is added. 
cluster_scheduling_resources_.emplace( - node_id, std::make_shared( - ResourceMapToNodeResources(resource_mapping, resource_mapping))); + node_id, + std::make_shared( + ResourceMapToNodeResources(resource_mapping, resource_mapping))); } } diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.h b/src/ray/gcs/gcs_server/gcs_resource_manager.h index bee41f389..d82507ad8 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.h @@ -161,7 +161,8 @@ class GcsResourceManager : public rpc::NodeResourceInfoHandler { const std::vector &resource_names); void UpdateResourceCapacity(NodeResources *node_resources, - const std::string &resource_name, double capacity); + const std::string &resource_name, + double capacity); /// The runner to run function periodically. PeriodicalRunner periodical_runner_; diff --git a/src/ray/gcs/gcs_server/gcs_resource_report_poller.cc b/src/ray/gcs/gcs_server/gcs_resource_report_poller.cc index a9cfadfae..0ae5f3fa3 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_report_poller.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_report_poller.cc @@ -22,7 +22,8 @@ GcsResourceReportPoller::GcsResourceReportPoller( std::function handle_resource_report, std::function get_current_time_milli, std::function &, + const rpc::Address &, + std::shared_ptr &, std::function)> request_report) : ticker_(polling_service_), @@ -52,7 +53,8 @@ void GcsResourceReportPoller::Start() { "the cluster has stopped"; }}); ticker_.RunFnPeriodically( - [this] { TryPullResourceReport(); }, 10, + [this] { TryPullResourceReport(); }, + 10, "GcsResourceReportPoller.deadline_timer.pull_resource_report"); } @@ -75,9 +77,10 @@ void GcsResourceReportPoller::HandleNodeAdded(const rpc::GcsNodeInfo &node_info) address.set_ip_address(node_info.node_manager_address()); address.set_port(node_info.node_manager_port()); - auto state = - std::make_shared(NodeID::FromBinary(node_info.node_id()), - std::move(address), -1, get_current_time_milli_()); + auto state = std::make_shared(NodeID::FromBinary(node_info.node_id()), + std::move(address), + -1, + get_current_time_milli_()); const auto &node_id = state->node_id; @@ -128,7 +131,8 @@ void GcsResourceReportPoller::PullResourceReport(const std::shared_ptraddress, raylet_client_pool_, + state->address, + raylet_client_pool_, [this, state](const Status &status, const rpc::RequestResourceReportReply &reply) { if (status.ok()) { // TODO (Alex): This callback is always posted onto the main thread. Since most diff --git a/src/ray/gcs/gcs_server/gcs_resource_report_poller.h b/src/ray/gcs/gcs_server/gcs_resource_report_poller.h index e41d52c36..d5bedafd1 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_report_poller.h +++ b/src/ray/gcs/gcs_server/gcs_resource_report_poller.h @@ -54,14 +54,14 @@ class GcsResourceReportPoller { std::function get_current_time_milli = []() { return absl::GetCurrentTimeNanos() / (1000 * 1000); }, std::function &, + const rpc::Address &, + std::shared_ptr &, std::function)> request_report = [](const rpc::Address &address, std::shared_ptr &raylet_client_pool, std::function - callback) { + const rpc::RequestResourceReportReply &)> callback) { auto raylet_client = raylet_client_pool->GetOrConnectByAddress(address); raylet_client->RequestResourceReport(callback); }); @@ -105,7 +105,8 @@ class GcsResourceReportPoller { std::function get_current_time_milli_; // Send the `RequestResourceReport` RPC. 
std::function &, + const rpc::Address &, + std::shared_ptr &, std::function)> request_report_; // The minimum delay between two pull requests to the same thread. @@ -117,7 +118,9 @@ class GcsResourceReportPoller { int64_t last_pull_time; int64_t next_pull_time; - PullState(NodeID _node_id, rpc::Address _address, int64_t _last_pull_time, + PullState(NodeID _node_id, + rpc::Address _address, + int64_t _last_pull_time, int64_t _next_pull_time) : node_id(_node_id), address(_address), diff --git a/src/ray/gcs/gcs_server/gcs_resource_scheduler.cc b/src/ray/gcs/gcs_server/gcs_resource_scheduler.cc index 9cc516a54..4baccf6bc 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_scheduler.cc @@ -32,7 +32,8 @@ double LeastResourceScorer::Score(const ResourceRequest &required_resources, if (!node_resources.normal_task_resources.IsEmpty()) { new_node_resources = node_resources; for (size_t i = 0; - i < node_resources.normal_task_resources.predefined_resources.size(); ++i) { + i < node_resources.normal_task_resources.predefined_resources.size(); + ++i) { new_node_resources.predefined_resources[i].available -= node_resources.normal_task_resources.predefined_resources[i]; if (new_node_resources.predefined_resources[i].available < 0) { @@ -330,7 +331,8 @@ SchedulingResult GcsResourceScheduler::StrictPackSchedule( const auto &cluster_resource = GetResourceView(); const auto &right_node_it = std::find_if( - cluster_resource.begin(), cluster_resource.end(), + cluster_resource.begin(), + cluster_resource.end(), [&aggregated_resource_request](const auto &entry) { return entry.second->GetLocalView().IsAvailable(aggregated_resource_request); }); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index d925f7736..42526a1b7 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -38,8 +38,10 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, : config_(config), storage_type_(StorageType()), main_service_(main_service), - rpc_server_(config.grpc_server_name, config.grpc_server_port, - config.node_ip_address == "127.0.0.1", config.grpc_server_thread_num, + rpc_server_(config.grpc_server_name, + config.grpc_server_port, + config.node_ip_address == "127.0.0.1", + config.grpc_server_thread_num, /*keepalive_time_ms=*/RayConfig::instance().grpc_keepalive_time_ms()), client_call_manager_(main_service, RayConfig::instance().gcs_server_rpc_client_thread_num()), @@ -68,8 +70,8 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, }; ray::rpc::StoredConfig stored_config; stored_config.set_config(config_.raylet_config_list); - RAY_CHECK_OK(gcs_table_storage_->InternalConfigTable().Put(ray::UniqueID::Nil(), - stored_config, on_done)); + RAY_CHECK_OK(gcs_table_storage_->InternalConfigTable().Put( + ray::UniqueID::Nil(), stored_config, on_done)); // Here we need to make sure the Put of internal config is happening in a synchronous // way. But since the storage API is async, we need to run the main_service_ // to block the current thread.
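The constructor comment just above ("run the main_service_ to block the current thread") relies on a pattern that is never spelled out in the patch: an asynchronous storage call is made synchronous by pumping the event loop on the calling thread until the completion callback has fired. A self-contained sketch of the idea using plain boost::asio; AsyncPutConfig is a hypothetical stand-in, not the actual GCS storage API:

    #include <boost/asio.hpp>
    #include <functional>
    #include <iostream>

    // Hypothetical stand-in for an async storage Put whose completion
    // handler is posted back onto the event loop.
    void AsyncPutConfig(boost::asio::io_context &io,
                        std::function<void(bool)> on_done) {
      boost::asio::post(io, [on_done = std::move(on_done)] { on_done(true); });
    }

    int main() {
      boost::asio::io_context io;
      bool done = false;
      AsyncPutConfig(io, [&done](bool ok) { done = ok; });
      // Make the async call effectively synchronous: run queued handlers
      // one at a time on the current thread until the callback has fired.
      while (!done) {
        io.run_one();
      }
      std::cout << "internal config stored" << std::endl;
    }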
@@ -110,8 +112,10 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, GcsServer::~GcsServer() { Stop(); } RedisClientOptions GcsServer::GetRedisClientOptions() const { - return RedisClientOptions(config_.redis_address, config_.redis_port, - config_.redis_password, config_.enable_sharding_conn); + return RedisClientOptions(config_.redis_address, + config_.redis_port, + config_.redis_password, + config_.enable_sharding_conn); } void GcsServer::Start() { @@ -223,8 +227,8 @@ void GcsServer::Stop() { void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) { RAY_CHECK(gcs_table_storage_ && gcs_publisher_); - gcs_node_manager_ = std::make_shared(gcs_publisher_, gcs_table_storage_, - raylet_client_pool_); + gcs_node_manager_ = std::make_shared( + gcs_publisher_, gcs_table_storage_, raylet_client_pool_); // Initialize by gcs tables data. gcs_node_manager_->Initialize(gcs_init_data); // Register service. @@ -292,8 +296,8 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { // gcs_actor_scheduler will treat it as failed and invoke this handler. In // this case, the actor manager should schedule the actor once an // eligible node is registered. - gcs_actor_manager_->OnActorSchedulingFailed(std::move(actor), failure_type, - scheduling_failure_message); + gcs_actor_manager_->OnActorSchedulingFailed( + std::move(actor), failure_type, scheduling_failure_message); }; auto schedule_success_handler = [this](std::shared_ptr actor, const rpc::PushTaskReply &reply) { @@ -305,19 +309,32 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { if (RayConfig::instance().gcs_actor_scheduling_enabled()) { RAY_CHECK(gcs_resource_manager_ && gcs_resource_scheduler_); - scheduler = std::make_unique( - main_service_, gcs_table_storage_->ActorTable(), *gcs_node_manager_, - gcs_resource_manager_, gcs_resource_scheduler_, schedule_failure_handler, - schedule_success_handler, raylet_client_pool_, client_factory); + scheduler = std::make_unique(main_service_, + gcs_table_storage_->ActorTable(), + *gcs_node_manager_, + gcs_resource_manager_, + gcs_resource_scheduler_, + schedule_failure_handler, + schedule_success_handler, + raylet_client_pool_, + client_factory); } else { - scheduler = std::make_unique( - main_service_, gcs_table_storage_->ActorTable(), *gcs_node_manager_, - schedule_failure_handler, schedule_success_handler, raylet_client_pool_, - client_factory); + scheduler = + std::make_unique(main_service_, + gcs_table_storage_->ActorTable(), + *gcs_node_manager_, + schedule_failure_handler, + schedule_success_handler, + raylet_client_pool_, + client_factory); } gcs_actor_manager_ = std::make_shared( - main_service_, std::move(scheduler), gcs_table_storage_, gcs_publisher_, - *runtime_env_manager_, *function_manager_, + main_service_, + std::move(scheduler), + gcs_table_storage_, + gcs_publisher_, + *runtime_env_manager_, + *function_manager_, [this](const ActorID &actor_id) { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id); }, @@ -351,12 +368,19 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { void GcsServer::InitGcsPlacementGroupManager(const GcsInitData &gcs_init_data) { RAY_CHECK(gcs_table_storage_ && gcs_node_manager_); - auto scheduler = std::make_shared( - main_service_, gcs_table_storage_, *gcs_node_manager_, *gcs_resource_manager_, - *gcs_resource_scheduler_, raylet_client_pool_, *ray_syncer_); + auto scheduler = std::make_shared(main_service_, + gcs_table_storage_, + *gcs_node_manager_, 
+ *gcs_resource_manager_, + *gcs_resource_scheduler_, + raylet_client_pool_, + *ray_syncer_); gcs_placement_group_manager_ = std::make_shared( - main_service_, scheduler, gcs_table_storage_, *gcs_resource_manager_, + main_service_, + scheduler, + gcs_table_storage_, + *gcs_resource_manager_, [this](const JobID &job_id) { return gcs_job_manager_->GetJobConfig(job_id)->ray_namespace(); }); @@ -406,8 +430,8 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { raylet -> syncer::poller --> syncer::update -> gcs_resource_manager gcs_placement_scheduler --/ */ - ray_syncer_ = std::make_unique(main_service_, raylet_client_pool_, - *gcs_resource_manager_); + ray_syncer_ = std::make_unique( + main_service_, raylet_client_pool_, *gcs_resource_manager_); ray_syncer_->Initialize(gcs_init_data); ray_syncer_->Start(); } @@ -472,7 +496,9 @@ void GcsServer::InitRuntimeEnvManager() { } else { auto uri = plugin_uri.substr(protocol_pos); this->kv_manager_->GetInstance().Del( - "" /* namespace */, uri /* key */, false /* del_by_prefix*/, + "" /* namespace */, + uri /* key */, + false /* del_by_prefix*/, [callback = std::move(callback)](int64_t) { callback(false); }); } } @@ -523,7 +549,9 @@ void GcsServer::InstallEventListeners() { if (worker_failure_data->has_creation_task_exception()) { creation_task_exception = &worker_failure_data->creation_task_exception(); } - gcs_actor_manager_->OnWorkerDead(node_id, worker_id, worker_ip, + gcs_actor_manager_->OnWorkerDead(node_id, + worker_id, + worker_ip, worker_failure_data->exit_type(), creation_task_exception); }); @@ -552,7 +580,8 @@ void GcsServer::RecordMetrics() const { gcs_actor_manager_->RecordMetrics(); gcs_placement_group_manager_->RecordMetrics(); execute_after( - main_service_, [this] { RecordMetrics(); }, + main_service_, + [this] { RecordMetrics(); }, (RayConfig::instance().metrics_report_interval_ms() / 2) /* milliseconds */); } diff --git a/src/ray/gcs/gcs_server/gcs_server_main.cc b/src/ray/gcs/gcs_server/gcs_server_main.cc index e2123643b..6016dbf31 100644 --- a/src/ray/gcs/gcs_server/gcs_server_main.cc +++ b/src/ray/gcs/gcs_server/gcs_server_main.cc @@ -35,8 +35,10 @@ DEFINE_string(node_ip_address, "", "The ip address of the node."); int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], - ray::RayLogLevel::INFO, /*log_dir=*/""); + ray::RayLog::ShutDownRayLog, + argv[0], + ray::RayLogLevel::INFO, + /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -70,7 +72,8 @@ int main(int argc, char *argv[]) { // Initialize event framework. 
if (RayConfig::instance().event_log_reporter_enabled() && !log_dir.empty()) { ray::RayEventInit(ray::rpc::Event_SourceType::Event_SourceType_GCS, - absl::flat_hash_map(), log_dir, + absl::flat_hash_map(), + log_dir, RayConfig::instance().event_level()); } diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.cc b/src/ray/gcs/gcs_server/gcs_table_storage.cc index 3044344e5..a8e3d9a43 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.cc +++ b/src/ray/gcs/gcs_server/gcs_table_storage.cc @@ -22,10 +22,11 @@ namespace ray { namespace gcs { template -Status GcsTable::Put(const Key &key, const Data &value, +Status GcsTable::Put(const Key &key, + const Data &value, const StatusCallback &callback) { - return store_client_->AsyncPut(table_name_, key.Binary(), value.SerializeAsString(), - callback); + return store_client_->AsyncPut( + table_name_, key.Binary(), value.SerializeAsString(), callback); } template @@ -71,16 +72,19 @@ Status GcsTable::BatchDelete(const std::vector &keys, for (auto &key : keys) { keys_to_delete.emplace_back(std::move(key.Binary())); } - return this->store_client_->AsyncBatchDelete(this->table_name_, keys_to_delete, - callback); + return this->store_client_->AsyncBatchDelete( + this->table_name_, keys_to_delete, callback); } template -Status GcsTableWithJobId::Put(const Key &key, const Data &value, +Status GcsTableWithJobId::Put(const Key &key, + const Data &value, const StatusCallback &callback) { - return this->store_client_->AsyncPutWithIndex(this->table_name_, key.Binary(), + return this->store_client_->AsyncPutWithIndex(this->table_name_, + key.Binary(), GetJobIdFromKey(key).Binary(), - value.SerializeAsString(), callback); + value.SerializeAsString(), + callback); } template @@ -95,15 +99,15 @@ Status GcsTableWithJobId::GetByJobId(const JobID &job_id, } callback(std::move(values)); }; - return this->store_client_->AsyncGetByIndex(this->table_name_, job_id.Binary(), - on_done); + return this->store_client_->AsyncGetByIndex( + this->table_name_, job_id.Binary(), on_done); } template Status GcsTableWithJobId::DeleteByJobId(const JobID &job_id, const StatusCallback &callback) { - return this->store_client_->AsyncDeleteByIndex(this->table_name_, job_id.Binary(), - callback); + return this->store_client_->AsyncDeleteByIndex( + this->table_name_, job_id.Binary(), callback); } template @@ -122,8 +126,8 @@ Status GcsTableWithJobId::BatchDelete(const std::vector &keys, keys_to_delete.push_back(key.Binary()); indexs_to_delete.push_back(GetJobIdFromKey(key).Binary()); } - return this->store_client_->AsyncBatchDeleteWithIndex(this->table_name_, keys_to_delete, - indexs_to_delete, callback); + return this->store_client_->AsyncBatchDeleteWithIndex( + this->table_name_, keys_to_delete, indexs_to_delete, callback); } template class GcsTable; diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.cc b/src/ray/gcs/gcs_server/gcs_worker_manager.cc index 1fb784a77..f14225a1f 100644 --- a/src/ray/gcs/gcs_server/gcs_worker_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_worker_manager.cc @@ -20,16 +20,22 @@ namespace ray { namespace gcs { void GcsWorkerManager::HandleReportWorkerFailure( - const rpc::ReportWorkerFailureRequest &request, rpc::ReportWorkerFailureReply *reply, + const rpc::ReportWorkerFailureRequest &request, + rpc::ReportWorkerFailureReply *reply, rpc::SendReplyCallback send_reply_callback) { const rpc::Address worker_address = request.worker_failure().worker_address(); const auto worker_id = WorkerID::FromBinary(worker_address.worker_id()); const auto node_id = 
NodeID::FromBinary(worker_address.raylet_id()); - std::string message = absl::StrCat( - "Reporting worker exit, worker id = ", worker_id.Hex(), - ", node id = ", node_id.Hex(), ", address = ", worker_address.ip_address(), - ", exit_type = ", rpc::WorkerExitType_Name(request.worker_failure().exit_type()), - request.worker_failure().has_creation_task_exception()); + std::string message = + absl::StrCat("Reporting worker exit, worker id = ", + worker_id.Hex(), + ", node id = ", + node_id.Hex(), + ", address = ", + worker_address.ip_address(), + ", exit_type = ", + rpc::WorkerExitType_Name(request.worker_failure().exit_type()), + request.worker_failure().has_creation_task_exception()); if (request.worker_failure().exit_type() == rpc::WorkerExitType::INTENDED_EXIT || request.worker_failure().exit_type() == rpc::WorkerExitType::IDLE_EXIT) { RAY_LOG(DEBUG) << message; @@ -47,7 +53,12 @@ void GcsWorkerManager::HandleReportWorkerFailure( listener(worker_failure_data); } - auto on_done = [this, worker_address, worker_id, node_id, worker_failure_data, reply, + auto on_done = [this, + worker_address, + worker_id, + node_id, + worker_failure_data, + reply, send_reply_callback](const Status &status) { if (!status.ok()) { RAY_LOG(ERROR) << "Failed to report worker failure, worker id = " << worker_id @@ -100,7 +111,8 @@ void GcsWorkerManager::HandleGetWorkerInfo(const rpc::GetWorkerInfoRequest &requ } void GcsWorkerManager::HandleGetAllWorkerInfo( - const rpc::GetAllWorkerInfoRequest &request, rpc::GetAllWorkerInfoReply *reply, + const rpc::GetAllWorkerInfoRequest &request, + rpc::GetAllWorkerInfoReply *reply, rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Getting all worker info."; auto on_done = [reply, send_reply_callback]( @@ -125,15 +137,15 @@ void GcsWorkerManager::HandleAddWorkerInfo(const rpc::AddWorkerInfoRequest &requ auto worker_id = WorkerID::FromBinary(worker_data->worker_address().worker_id()); RAY_LOG(DEBUG) << "Adding worker " << worker_id; - auto on_done = [worker_id, worker_data, reply, - send_reply_callback](const Status &status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to add worker information, " - << worker_data->DebugString(); - } - RAY_LOG(DEBUG) << "Finished adding worker " << worker_id; - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }; + auto on_done = + [worker_id, worker_data, reply, send_reply_callback](const Status &status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to add worker information, " + << worker_data->DebugString(); + } + RAY_LOG(DEBUG) << "Finished adding worker " << worker_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; Status status = gcs_table_storage_->WorkerTable().Put(worker_id, *worker_data, on_done); if (!status.ok()) { diff --git a/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.cc b/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.cc index 0371bd11e..5dc8bf3ce 100644 --- a/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.cc +++ b/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.cc @@ -22,7 +22,8 @@ namespace gcs { GrpcBasedResourceBroadcaster::GrpcBasedResourceBroadcaster( std::shared_ptr raylet_client_pool, std::function &, std::string &, + std::shared_ptr &, + std::string &, const rpc::ClientCallback &)> send_batch diff --git a/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.h b/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.h index 7f37ccd96..4b52c0fb6 100644 --- a/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.h +++ 
b/src/ray/gcs/gcs_server/grpc_based_resource_broadcaster.h @@ -28,7 +28,8 @@ class GrpcBasedResourceBroadcaster { std::shared_ptr raylet_client_pool, /* Default values should only be changed for testing. */ std::function &, std::string &, + std::shared_ptr &, + std::string &, const rpc::ClientCallback &)> send_batch = [](const rpc::Address &address, @@ -60,7 +61,8 @@ class GrpcBasedResourceBroadcaster { // The shared, thread safe pool of raylet clients, which we use to minimize connections. std::shared_ptr raylet_client_pool_; - std::function &, + std::function &, std::string &, const rpc::ClientCallback &)> send_batch_; diff --git a/src/ray/gcs/gcs_server/pubsub_handler.cc b/src/ray/gcs/gcs_server/pubsub_handler.cc index a765597b1..95091f23e 100644 --- a/src/ray/gcs/gcs_server/pubsub_handler.cc +++ b/src/ray/gcs/gcs_server/pubsub_handler.cc @@ -36,7 +36,8 @@ void InternalPubSubHandler::HandleGcsPublish(const rpc::GcsPublishRequest &reque send_reply_callback( Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with " "system config `gcs_grpc_based_pubsub=True`"), - nullptr, nullptr); + nullptr, + nullptr); return; } for (const auto &msg : request.pub_messages()) { @@ -49,13 +50,15 @@ void InternalPubSubHandler::HandleGcsPublish(const rpc::GcsPublishRequest &reque // and convert the reply to rpc::PubsubLongPollingReply because GCS RPC services are // required to have the `status` field in replies. void InternalPubSubHandler::HandleGcsSubscriberPoll( - const rpc::GcsSubscriberPollRequest &request, rpc::GcsSubscriberPollReply *reply, + const rpc::GcsSubscriberPollRequest &request, + rpc::GcsSubscriberPollReply *reply, rpc::SendReplyCallback send_reply_callback) { if (gcs_publisher_ == nullptr) { send_reply_callback( Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with " "system config `gcs_grpc_based_pubsub=True`"), - nullptr, nullptr); + nullptr, + nullptr); return; } rpc::PubsubLongPollingRequest pubsub_req; @@ -63,8 +66,10 @@ void InternalPubSubHandler::HandleGcsSubscriberPoll( auto pubsub_reply = std::make_shared(); auto pubsub_reply_ptr = pubsub_reply.get(); gcs_publisher_->GetPublisher()->ConnectToSubscriber( - pubsub_req, pubsub_reply_ptr, - [reply, reply_cb = std::move(send_reply_callback), + pubsub_req, + pubsub_reply_ptr, + [reply, + reply_cb = std::move(send_reply_callback), pubsub_reply = std::move(pubsub_reply)](ray::Status status, std::function success_cb, std::function failure_cb) { @@ -84,18 +89,21 @@ void InternalPubSubHandler::HandleGcsSubscriberCommandBatch( send_reply_callback( Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with " "system config `gcs_grpc_based_pubsub=True`"), - nullptr, nullptr); + nullptr, + nullptr); return; } const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id()); for (const auto &command : request.commands()) { if (command.has_unsubscribe_message()) { gcs_publisher_->GetPublisher()->UnregisterSubscription( - command.channel_type(), subscriber_id, + command.channel_type(), + subscriber_id, command.key_id().empty() ? std::nullopt : std::make_optional(command.key_id())); } else if (command.has_subscribe_message()) { gcs_publisher_->GetPublisher()->RegisterSubscription( - command.channel_type(), subscriber_id, + command.channel_type(), + subscriber_id, command.key_id().empty() ? 
std::nullopt : std::make_optional(command.key_id())); } else { RAY_LOG(FATAL) << "Invalid command received, " diff --git a/src/ray/gcs/gcs_server/ray_syncer.h b/src/ray/gcs/gcs_server/ray_syncer.h index 7f9a210eb..bc0337029 100644 --- a/src/ray/gcs/gcs_server/ray_syncer.h +++ b/src/ray/gcs/gcs_server/ray_syncer.h @@ -66,7 +66,8 @@ class RaySyncer { static auto max_batch = RayConfig::instance().resource_broadcast_batch_size(); // Prepare the to-be-sent messages. for (size_t cnt = resources_buffer_proto_.batch().size(); - cnt < max_batch && cnt < resources_buffer_.size(); ++ptr, ++cnt) { + cnt < max_batch && cnt < resources_buffer_.size(); + ++ptr, ++cnt) { resources_buffer_proto_.add_batch()->mutable_data()->Swap(&ptr->second); } resources_buffer_.erase(beg, ptr); diff --git a/src/ray/gcs/gcs_server/stats_handler_impl.cc b/src/ray/gcs/gcs_server/stats_handler_impl.cc index 0bf86a792..2d9425140 100644 --- a/src/ray/gcs/gcs_server/stats_handler_impl.cc +++ b/src/ray/gcs/gcs_server/stats_handler_impl.cc @@ -41,8 +41,8 @@ void DefaultStatsHandler::HandleAddProfileData(const AddProfileDataRequest &requ // record when it reaches the upper limit. When we receive a record, we update the // `cursor_` and get the corresponding id through it. The Put operation will directly overwrite // the previous data, so that we can avoid a delete operation. - Status status = gcs_table_storage_->ProfileTable().Put(ids_[cursor_++ % ids_.size()], - *profile_table_data, on_done); + Status status = gcs_table_storage_->ProfileTable().Put( + ids_[cursor_++ % ids_.size()], *profile_table_data, on_done); if (!status.ok()) { on_done(status); } @@ -51,7 +51,8 @@ void DefaultStatsHandler::HandleAddProfileData(const AddProfileDataRequest &requ } void DefaultStatsHandler::HandleGetAllProfileInfo( - const rpc::GetAllProfileInfoRequest &request, rpc::GetAllProfileInfoReply *reply, + const rpc::GetAllProfileInfoRequest &request, + rpc::GetAllProfileInfoReply *reply, rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Getting all profile info."; auto on_done = [reply, send_reply_callback]( diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index ee3cc34f5..91e1cc286 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -42,8 +42,10 @@ class MockActorScheduler : public gcs::GcsActorSchedulerInterface { MOCK_CONST_METHOD0(DebugString, std::string()); MOCK_METHOD1(CancelOnNode, std::vector(const NodeID &node_id)); MOCK_METHOD2(CancelOnWorker, ActorID(const NodeID &node_id, const WorkerID &worker_id)); - MOCK_METHOD3(CancelOnLeasing, void(const NodeID &node_id, const ActorID &actor_id, - const TaskID &task_id)); + MOCK_METHOD3(CancelOnLeasing, + void(const NodeID &node_id, + const ActorID &actor_id, + const TaskID &task_id)); std::vector> actors; }; @@ -116,8 +118,13 @@ class GcsActorManagerTest : public ::testing::Test { kv_ = std::make_unique(); function_manager_ = std::make_unique(*kv_); gcs_actor_manager_ = std::make_unique( - io_service_, mock_actor_scheduler_, gcs_table_storage_, gcs_publisher_, - *runtime_env_mgr_, *function_manager_, [](const ActorID &actor_id) {}, + io_service_, + mock_actor_scheduler_, + gcs_table_storage_, + gcs_publisher_, + *runtime_env_mgr_, + *function_manager_, + [](const ActorID &actor_id) {}, [this](const JobID &job_id) { auto job_config = std::make_shared(); job_config->set_ray_namespace(job_namespace_table_[job_id]); @@ -178,13 +185,14 @@ class
GcsActorManagerTest : public ::testing::Test { return address; } - std::shared_ptr RegisterActor(const JobID &job_id, int max_restarts = 0, + std::shared_ptr RegisterActor(const JobID &job_id, + int max_restarts = 0, bool detached = false, const std::string &name = "", const std::string &ray_namespace = "") { std::promise> promise; - auto request = Mocker::GenRegisterActorRequest(job_id, max_restarts, detached, name, - ray_namespace); + auto request = Mocker::GenRegisterActorRequest( + job_id, max_restarts, detached, name, ray_namespace); // `DestroyActor` triggers some asynchronous operations. // If we register an actor after destroying an actor, it may result in multithreading // reading and writing the same variable. In order to avoid the problem of @@ -277,8 +285,9 @@ TEST_F(GcsActorManagerTest, TestSchedulingFailed) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); @@ -310,8 +319,9 @@ TEST_F(GcsActorManagerTest, TestWorkerFailure) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); @@ -357,8 +367,9 @@ TEST_F(GcsActorManagerTest, TestNodeFailure) { std::vector> finished_actors; Status status = gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); }); RAY_CHECK_OK(status); @@ -399,7 +410,8 @@ TEST_F(GcsActorManagerTest, TestNodeFailure) { TEST_F(GcsActorManagerTest, TestActorReconstruction) { auto job_id = JobID::FromInt(1); - auto registered_actor = RegisterActor(job_id, /*max_restarts=*/1, + auto registered_actor = RegisterActor(job_id, + /*max_restarts=*/1, /*detached=*/false); rpc::CreateActorRequest create_actor_request; create_actor_request.mutable_task_spec()->CopyFrom( @@ -407,8 +419,9 @@ TEST_F(GcsActorManagerTest, TestActorReconstruction) { std::vector> finished_actors; Status status = gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); }); RAY_CHECK_OK(status); @@ -467,7 +480,8 @@ TEST_F(GcsActorManagerTest, TestActorReconstruction) { TEST_F(GcsActorManagerTest, TestActorRestartWhenOwnerDead) { auto job_id = JobID::FromInt(1); - auto registered_actor = RegisterActor(job_id, /*max_restarts=*/1, + auto registered_actor = RegisterActor(job_id, + /*max_restarts=*/1, /*detached=*/false); rpc::CreateActorRequest create_actor_request; create_actor_request.mutable_task_spec()->CopyFrom( @@ -475,8 +489,9 @@ TEST_F(GcsActorManagerTest, TestActorRestartWhenOwnerDead) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + 
[&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); @@ -517,7 +532,8 @@ TEST_F(GcsActorManagerTest, TestActorRestartWhenOwnerDead) { TEST_F(GcsActorManagerTest, TestDetachedActorRestartWhenCreatorDead) { auto job_id = JobID::FromInt(1); - auto registered_actor = RegisterActor(job_id, /*max_restarts=*/1, + auto registered_actor = RegisterActor(job_id, + /*max_restarts=*/1, /*detached=*/true); rpc::CreateActorRequest create_actor_request; create_actor_request.mutable_task_spec()->CopyFrom( @@ -525,8 +541,9 @@ TEST_F(GcsActorManagerTest, TestDetachedActorRestartWhenCreatorDead) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); @@ -555,8 +572,10 @@ TEST_F(GcsActorManagerTest, TestActorWithEmptyName) { // Gen `CreateActorRequest` with an empty name. // (name,actor_id) => ("", actor_id_1) - auto request1 = Mocker::GenRegisterActorRequest(job_id, /*max_restarts=*/0, - /*detached=*/true, /*name=*/""); + auto request1 = Mocker::GenRegisterActorRequest(job_id, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/""); Status status = gcs_actor_manager_->RegisterActor( request1, [](std::shared_ptr actor) {}); // Ensure successful registration. @@ -566,8 +585,10 @@ TEST_F(GcsActorManagerTest, TestActorWithEmptyName) { // Gen another `CreateActorRequest` with an empty name. // (name,actor_id) => ("", actor_id_2) - auto request2 = Mocker::GenRegisterActorRequest(job_id, /*max_restarts=*/0, - /*detached=*/true, /*name=*/""); + auto request2 = Mocker::GenRegisterActorRequest(job_id, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/""); status = gcs_actor_manager_->RegisterActor(request2, [](std::shared_ptr actor) {}); // Ensure successful registration. @@ -578,16 +599,20 @@ TEST_F(GcsActorManagerTest, TestNamedActors) { auto job_id_1 = JobID::FromInt(1); auto job_id_2 = JobID::FromInt(2); - auto request1 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor1"); + auto request1 = Mocker::GenRegisterActorRequest(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor1"); Status status = gcs_actor_manager_->RegisterActor( request1, [](std::shared_ptr actor) {}); ASSERT_TRUE(status.ok()); ASSERT_EQ(gcs_actor_manager_->GetActorIDByName("actor1", "").Binary(), request1.task_spec().actor_creation_task_spec().actor_id()); - auto request2 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor2"); + auto request2 = Mocker::GenRegisterActorRequest(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor2"); status = gcs_actor_manager_->RegisterActor(request2, [](std::shared_ptr actor) {}); ASSERT_TRUE(status.ok()); @@ -598,8 +623,10 @@ TEST_F(GcsActorManagerTest, TestNamedActors) { ASSERT_EQ(gcs_actor_manager_->GetActorIDByName("actor3", ""), ActorID::Nil()); // Check that naming collisions return Status::Invalid. 
- auto request3 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor2"); + auto request3 = Mocker::GenRegisterActorRequest(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor2"); status = gcs_actor_manager_->RegisterActor(request3, [](std::shared_ptr actor) {}); ASSERT_TRUE(status.IsNotFound()); @@ -607,8 +634,10 @@ TEST_F(GcsActorManagerTest, TestNamedActors) { request2.task_spec().actor_creation_task_spec().actor_id()); // Check that naming collisions are enforced across JobIDs. - auto request4 = Mocker::GenRegisterActorRequest(job_id_2, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor2"); + auto request4 = Mocker::GenRegisterActorRequest(job_id_2, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor2"); status = gcs_actor_manager_->RegisterActor(request4, [](std::shared_ptr actor) {}); ASSERT_TRUE(status.IsNotFound()); @@ -620,8 +649,10 @@ TEST_F(GcsActorManagerTest, TestNamedActorDeletionWorkerFailure) { // Make sure named actor deletion succeeds when workers fail. const auto actor_name = "actor_to_delete"; const auto job_id_1 = JobID::FromInt(1); - auto registered_actor_1 = RegisterActor(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/actor_name); + auto registered_actor_1 = RegisterActor(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/actor_name); rpc::CreateActorRequest request1; request1.mutable_task_spec()->CopyFrom( registered_actor_1->GetActorTableData().task_spec()); @@ -655,8 +686,10 @@ TEST_F(GcsActorManagerTest, TestNamedActorDeletionWorkerFailure) { // Create an actor with the same name. This ensures that the name has been properly // deleted. - auto registered_actor_2 = RegisterActor(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/actor_name); + auto registered_actor_2 = RegisterActor(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/actor_name); rpc::CreateActorRequest request2; request2.mutable_task_spec()->CopyFrom( registered_actor_2->GetActorTableData().task_spec()); @@ -672,8 +705,10 @@ TEST_F(GcsActorManagerTest, TestNamedActorDeletionWorkerFailure) { TEST_F(GcsActorManagerTest, TestNamedActorDeletionNodeFailure) { // Make sure named actor deletion succeeds when nodes fail. const auto job_id_1 = JobID::FromInt(1); - auto registered_actor_1 = RegisterActor(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor"); + auto registered_actor_1 = RegisterActor(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor"); rpc::CreateActorRequest request1; request1.mutable_task_spec()->CopyFrom( registered_actor_1->GetActorTableData().task_spec()); @@ -706,8 +741,10 @@ TEST_F(GcsActorManagerTest, TestNamedActorDeletionNodeFailure) { // Create an actor with the same name. This ensures that the name has been properly // deleted. - auto registered_actor_2 = RegisterActor(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor"); + auto registered_actor_2 = RegisterActor(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor"); rpc::CreateActorRequest request2; request2.mutable_task_spec()->CopyFrom( registered_actor_2->GetActorTableData().task_spec()); @@ -724,8 +761,10 @@ TEST_F(GcsActorManagerTest, TestNamedActorDeletionNotHappendWhenReconstructed) { // Make sure named actor deletion succeeds when nodes fail. const auto job_id_1 = JobID::FromInt(1); // The dead actor will be reconstructed. 
- auto registered_actor_1 = RegisterActor(job_id_1, /*max_restarts=*/1, - /*detached=*/true, /*name=*/"actor"); + auto registered_actor_1 = RegisterActor(job_id_1, + /*max_restarts=*/1, + /*detached=*/true, + /*name=*/"actor"); rpc::CreateActorRequest request1; request1.mutable_task_spec()->CopyFrom( registered_actor_1->GetActorTableData().task_spec()); @@ -757,8 +796,10 @@ TEST_F(GcsActorManagerTest, TestNamedActorDeletionNotHappendWhenReconstructed) { // It should fail because actor has been reconstructed, and names shouldn't have been // cleaned. const auto job_id_2 = JobID::FromInt(2); - auto request2 = Mocker::GenRegisterActorRequest(job_id_2, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor"); + auto request2 = Mocker::GenRegisterActorRequest(job_id_2, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor"); status = gcs_actor_manager_->RegisterActor(request2, [](std::shared_ptr actor) {}); ASSERT_TRUE(status.IsNotFound()); @@ -775,8 +816,9 @@ TEST_F(GcsActorManagerTest, TestDestroyActorBeforeActorCreationCompletes) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); @@ -801,7 +843,8 @@ TEST_F(GcsActorManagerTest, TestDestroyActorBeforeActorCreationCompletes) { TEST_F(GcsActorManagerTest, TestRaceConditionCancelLease) { // Covers a scenario 1 in this PR https://github.com/ray-project/ray/pull/9215. auto job_id = JobID::FromInt(1); - auto registered_actor = RegisterActor(job_id, /*max_restarts=*/1, + auto registered_actor = RegisterActor(job_id, + /*max_restarts=*/1, /*detached=*/false); rpc::CreateActorRequest create_actor_request; create_actor_request.mutable_task_spec()->CopyFrom( @@ -809,8 +852,9 @@ TEST_F(GcsActorManagerTest, TestRaceConditionCancelLease) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); @@ -852,8 +896,9 @@ TEST_F(GcsActorManagerTest, TestRegisterActor) { request.mutable_task_spec()->CopyFrom( registered_actor->GetActorTableData().task_spec()); RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(std::move(actor)); })); // Make sure the actor is scheduling. 
@@ -962,7 +1007,8 @@ TEST_F(GcsActorManagerTest, TestOwnerNodeDieBeforeDetachedActorDependenciesResol TEST_F(GcsActorManagerTest, TestOwnerAndChildDiedAtTheSameTimeRaceCondition) { // When owner and child die at the same time, auto job_id = JobID::FromInt(1); - auto registered_actor = RegisterActor(job_id, /*max_restarts=*/1, + auto registered_actor = RegisterActor(job_id, + /*max_restarts=*/1, /*detached=*/false); rpc::CreateActorRequest create_actor_request; create_actor_request.mutable_task_spec()->CopyFrom( @@ -970,8 +1016,9 @@ TEST_F(GcsActorManagerTest, TestOwnerAndChildDiedAtTheSameTimeRaceCondition) { std::vector> finished_actors; RAY_CHECK_OK(gcs_actor_manager_->CreateActor( - create_actor_request, [&finished_actors](std::shared_ptr actor, - const rpc::PushTaskReply &reply) { + create_actor_request, + [&finished_actors](std::shared_ptr actor, + const rpc::PushTaskReply &reply) { finished_actors.emplace_back(actor); })); auto actor = mock_actor_scheduler_->actors.back(); @@ -1002,8 +1049,10 @@ TEST_F(GcsActorManagerTest, TestRayNamespace) { std::string second_namespace = "another_namespace"; job_namespace_table_[job_id_2] = second_namespace; - auto request1 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor"); + auto request1 = Mocker::GenRegisterActorRequest(job_id_1, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor"); { // Create an actor in the empty namespace Status status = gcs_actor_manager_->RegisterActor( @@ -1013,8 +1062,10 @@ TEST_F(GcsActorManagerTest, TestRayNamespace) { request1.task_spec().actor_creation_task_spec().actor_id()); } - auto request2 = Mocker::GenRegisterActorRequest(job_id_2, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor"); + auto request2 = Mocker::GenRegisterActorRequest(job_id_2, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor"); { // Create a second actor of the same name. Its job id belongs to a different // namespace though. Status status = gcs_actor_manager_->RegisterActor( @@ -1027,8 +1078,10 @@ TEST_F(GcsActorManagerTest, TestRayNamespace) { request1.task_spec().actor_creation_task_spec().actor_id()); } - auto request3 = Mocker::GenRegisterActorRequest(job_id_3, /*max_restarts=*/0, - /*detached=*/true, /*name=*/"actor"); + auto request3 = Mocker::GenRegisterActorRequest(job_id_3, + /*max_restarts=*/0, + /*detached=*/true, + /*name=*/"actor"); { // Actors from different jobs, in the same namespace should still collide. 
Status status = gcs_actor_manager_->RegisterActor( request3, [](std::shared_ptr actor) {}); @@ -1081,8 +1134,10 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { google::protobuf::Arena arena; skip_delay_ = false; auto job_id_1 = JobID::FromInt(1); - auto request1 = Mocker::GenRegisterActorRequest(job_id_1, /*max_restarts=*/0, - /*detached=*/false, /*name=*/"actor"); + auto request1 = Mocker::GenRegisterActorRequest(job_id_1, + /*max_restarts=*/0, + /*detached=*/false, + /*name=*/"actor"); Status status = gcs_actor_manager_->RegisterActor( request1, [](std::shared_ptr actor) {}); ASSERT_TRUE(status.ok()); @@ -1100,7 +1155,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { auto &reply = *google::protobuf::Arena::CreateMessage(&arena); bool called = false; - auto callback = [&called](Status status, std::function success, + auto callback = [&called](Status status, + std::function success, std::function failure) { called = true; }; gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); @@ -1112,7 +1168,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { *google::protobuf::Arena::CreateMessage(&arena); request.set_show_dead_jobs(true); std::promise promise; - auto callback = [&promise](Status status, std::function success, + auto callback = [&promise](Status status, + std::function success, std::function failure) { promise.set_value(); }; gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); promise.get_future().get(); @@ -1126,7 +1183,8 @@ TEST_F(GcsActorManagerTest, TestActorTableDataDelayedGC) { *google::protobuf::Arena::CreateMessage(&arena); request.set_show_dead_jobs(true); std::promise promise; - auto callback = [&promise](Status status, std::function success, + auto callback = [&promise](Status status, + std::function success, std::function failure) { promise.set_value(); }; gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); promise.get_future().get(); diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc index 879ea2109..a085d6393 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc @@ -43,10 +43,13 @@ class GcsActorSchedulerTest : public Test { client_pool = std::make_shared( [this](const rpc::Address &) { return raylet_client; }); actor_scheduler = std::make_unique( - io_context, *actor_table, *gcs_node_manager, + io_context, + *actor_table, + *gcs_node_manager, [this](auto a, auto b, auto c) { schedule_failure_handler(a); }, [this](auto a, const rpc::PushTaskReply) { schedule_success_handler(a); }, - client_pool, [this](const rpc::Address &) { return core_worker_client; }); + client_pool, + [this](const rpc::Address &) { return core_worker_client; }); auto node_info = std::make_shared(); node_info->set_state(rpc::GcsNodeInfo::ALIVE); node_id = NodeID::FromRandom(); diff --git a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc index 68a41c5b2..d2f75dc87 100644 --- a/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_based_actor_scheduler_test.cc @@ -45,7 +45,10 @@ class GcsBasedActorSchedulerTest : public ::testing::Test { std::make_shared(*gcs_resource_manager_); gcs_actor_scheduler_ = std::make_shared( - io_service_, *gcs_actor_table_, *gcs_node_manager_, gcs_resource_manager_, + io_service_, + 
*gcs_actor_table_, + *gcs_node_manager_, + gcs_resource_manager_, resource_scheduler, /*schedule_failure_handler=*/ [this](std::shared_ptr actor, @@ -74,9 +77,15 @@ class GcsBasedActorSchedulerTest : public ::testing::Test { auto job_id = JobID::FromInt(1); std::unordered_map required_resources; - auto actor_creating_task_spec = Mocker::GenActorCreationTask( - job_id, /*max_restarts=*/1, /*detached=*/true, /*name=*/"", "", owner_address, - required_resources, required_placement_resources); + auto actor_creating_task_spec = + Mocker::GenActorCreationTask(job_id, + /*max_restarts=*/1, + /*detached=*/true, + /*name=*/"", + "", + owner_address, + required_resources, + required_placement_resources); return std::make_shared(actor_creating_task_spec.GetMessage(), /*ray_namespace=*/""); } @@ -178,8 +187,10 @@ TEST_F(GcsBasedActorSchedulerTest, TestScheduleAndDestroyOneActor) { // Grant a worker, then the actor creation request should be sent to the worker. WorkerID worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -261,9 +272,13 @@ TEST_F(GcsBasedActorSchedulerTest, TestRejectedRequestWorkerLeaseReply) { ASSERT_EQ(0, worker_client_->callbacks.size()); // Mock a rejected reply, then the actor will be rescheduled. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node1->node_manager_address(), node1->node_manager_port(), WorkerID::FromRandom(), - node_id_1, NodeID::Nil(), Status::OK(), /*rejected=*/true)); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node1->node_manager_address(), + node1->node_manager_port(), + WorkerID::FromRandom(), + node_id_1, + NodeID::Nil(), + Status::OK(), + /*rejected=*/true)); ASSERT_EQ(2, raylet_client_->num_workers_requested); ASSERT_EQ(1, raylet_client_->callbacks.size()); ASSERT_EQ(0, worker_client_->callbacks.size()); @@ -294,9 +309,12 @@ TEST_F(GcsBasedActorSchedulerTest, TestScheduleRetryWhenLeasing) { ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_leasing_count_); // Mock an IOError reply, then the lease request will retry again. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil(), Status::IOError(""))); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil(), + Status::IOError(""))); ASSERT_EQ(1, gcs_actor_scheduler_->num_retry_leasing_count_); ASSERT_EQ(2, raylet_client_->num_workers_requested); ASSERT_EQ(1, raylet_client_->callbacks.size()); @@ -305,8 +323,10 @@ // Grant a worker, then the actor creation request should be sent to the worker. WorkerID worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -343,8 +363,10 @@ TEST_F(GcsBasedActorSchedulerTest, TestScheduleRetryWhenCreating) { // Grant a worker, then the actor creation request should be sent to the worker. 
WorkerID worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_creating_count_); @@ -394,9 +416,11 @@ TEST_F(GcsBasedActorSchedulerTest, TestNodeFailedWhenLeasing) { ASSERT_EQ(1, raylet_client_->callbacks.size()); // Grant a worker, which should have no effect. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil())); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil())); ASSERT_EQ(1, raylet_client_->num_workers_requested); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_leasing_count_); @@ -431,9 +455,11 @@ TEST_F(GcsBasedActorSchedulerTest, TestLeasingCancelledWhenLeasing) { ASSERT_EQ(1, raylet_client_->callbacks.size()); // Grant a worker, which should have no effect. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil())); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil())); ASSERT_EQ(1, raylet_client_->num_workers_requested); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_leasing_count_); @@ -463,9 +489,11 @@ TEST_F(GcsBasedActorSchedulerTest, TestNodeFailedWhenCreating) { ASSERT_EQ(0, worker_client_->callbacks.size()); // Grant a worker, then the actor creation request should be sent to the worker. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil())); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -510,8 +538,10 @@ TEST_F(GcsBasedActorSchedulerTest, TestWorkerFailedWhenCreating) { // Grant a worker, then the actor creation request should be sent to the worker. auto worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -569,8 +599,10 @@ TEST_F(GcsBasedActorSchedulerTest, TestReschedule) { // Grant a worker, then the actor creation request should be sent to the worker. 
ASSERT_TRUE(raylet_client_->GrantWorkerLease(node1->node_manager_address(), - node1->node_manager_port(), worker_id, - node_id_1, NodeID::Nil())); + node1->node_manager_port(), + worker_id, + node_id_1, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); diff --git a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc index 1bfafa113..017ff031b 100644 --- a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc @@ -33,8 +33,10 @@ class MockInMemoryStoreClient : public gcs::InMemoryStoreClient { explicit MockInMemoryStoreClient(instrumented_io_context &main_io_service) : gcs::InMemoryStoreClient(main_io_service) {} - Status AsyncPut(const std::string &table_name, const std::string &key, - const std::string &data, const gcs::StatusCallback &callback) override { + Status AsyncPut(const std::string &table_name, + const std::string &key, + const std::string &data, + const gcs::StatusCallback &callback) override { callback(Status::OK()); return Status::OK(); } @@ -79,8 +81,8 @@ class GcsJobManagerTest : public ::testing::Test { }; TEST_F(GcsJobManagerTest, TestGetJobConfig) { - gcs::GcsJobManager gcs_job_manager(gcs_table_storage_, gcs_publisher_, - runtime_env_manager_, *function_manager_); + gcs::GcsJobManager gcs_job_manager( + gcs_table_storage_, gcs_publisher_, runtime_env_manager_, *function_manager_); auto job_id1 = JobID::FromInt(1); auto job_id2 = JobID::FromInt(2); @@ -91,11 +93,13 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) { rpc::AddJobReply empty_reply; gcs_job_manager.HandleAddJob( - *add_job_request1, &empty_reply, + *add_job_request1, + &empty_reply, [](Status, std::function, std::function) {}); auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2", 8); gcs_job_manager.HandleAddJob( - *add_job_request2, &empty_reply, + *add_job_request2, + &empty_reply, [](Status, std::function, std::function) {}); auto job_config1 = gcs_job_manager.GetJobConfig(job_id1); diff --git a/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc index e14ac53d7..ae4f3547d 100644 --- a/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc @@ -97,7 +97,8 @@ TEST_P(GcsKVManagerTest, TestInternalKV) { } } -INSTANTIATE_TEST_SUITE_P(GcsKVManagerTestFixture, GcsKVManagerTest, +INSTANTIATE_TEST_SUITE_P(GcsKVManagerTestFixture, + GcsKVManagerTest, ::testing::Values("redis", "memory")); int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc index 9037c7774..8618af7e8 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc @@ -39,9 +39,12 @@ class GcsPlacementGroupManagerMockTest : public Test { resource_manager_ = std::make_shared(io_context_, nullptr, nullptr); - gcs_placement_group_manager_ = std::make_unique( - io_context_, gcs_placement_group_scheduler_, gcs_table_storage_, - *resource_manager_, [](auto &) { return ""; }); + gcs_placement_group_manager_ = + std::make_unique(io_context_, + gcs_placement_group_scheduler_, + gcs_table_storage_, + *resource_manager_, + [](auto &) { return ""; }); } std::unique_ptr gcs_placement_group_manager_; diff --git 
a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc index ef388d3e4..073245594 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc @@ -49,10 +49,11 @@ class MockPlacementGroupScheduler : public gcs::GcsPlacementGroupSchedulerInterf ReleaseUnusedBundles, void(const absl::flat_hash_map> &node_to_bundles)); - MOCK_METHOD1(Initialize, - void(const absl::flat_hash_map< - PlacementGroupID, std::vector>> - &group_to_bundles)); + MOCK_METHOD1( + Initialize, + void(const absl::flat_hash_map>> + &group_to_bundles)); absl::flat_hash_map> GetBundlesOnNode( const NodeID &node_id) override { @@ -82,7 +83,9 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { gcs_resource_manager_ = std::make_shared(io_service_, nullptr, nullptr); gcs_placement_group_manager_.reset(new gcs::GcsPlacementGroupManager( - io_service_, mock_placement_group_scheduler_, gcs_table_storage_, + io_service_, + mock_placement_group_scheduler_, + gcs_table_storage_, *gcs_resource_manager_, [this](const JobID &job_id) { return job_namespace_table_[job_id]; })); for (int i = 1; i <= 10; i++) { @@ -214,8 +217,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingFailed) { mock_placement_group_scheduler_->placement_groups_.clear(); ASSERT_EQ(placement_group->GetStats().scheduling_attempt(), 1); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed( + placement_group, GetExpBackOff(), true); gcs_placement_group_manager_->SchedulePendingPlacementGroups(); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 1); @@ -281,8 +284,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) { mock_placement_group_scheduler_->placement_groups_.pop_back(); // If the creation of placement group fails, it will be rescheduled after a short time. 
- gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed( + placement_group, GetExpBackOff(), true); WaitForExpectedPgCount(1); } @@ -297,8 +300,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingPendingPlacementGroup) { auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.clear(); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed( + placement_group, GetExpBackOff(), true); ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::PENDING); const auto &placement_group_id = placement_group->GetPlacementGroupID(); gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id, @@ -336,8 +339,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRemovingLeasingPlacementGroup) { gcs_placement_group_manager_->RemovePlacementGroup(placement_group_id, [](const Status &status) {}); ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::REMOVED); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed( + placement_group, GetExpBackOff(), true); // Make sure it is not rescheduled gcs_placement_group_manager_->SchedulePendingPlacementGroups(); @@ -425,8 +428,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeDead) { placement_group = mock_placement_group_scheduler_->placement_groups_.back(); mock_placement_group_scheduler_->placement_groups_.pop_back(); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_.size(), 0); - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), true); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed( + placement_group, GetExpBackOff(), true); WaitForExpectedPgCount(1); ASSERT_EQ(mock_placement_group_scheduler_->placement_groups_[0]->GetPlacementGroupID(), placement_group->GetPlacementGroupID()); @@ -467,7 +470,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestAutomaticCleanupWhenActorDeadAndJobDead const auto job_id = JobID::FromInt(1); const auto actor_id = ActorID::Of(job_id, TaskID::Nil(), 0); auto request = Mocker::GenCreatePlacementGroupRequest( - /* name */ "", rpc::PlacementStrategy::SPREAD, + /* name */ "", + rpc::PlacementStrategy::SPREAD, /* bundles_count */ 2, /* cpu_num */ 1.0, /* job_id */ job_id, @@ -500,7 +504,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestAutomaticCleanupWhenActorAndJobDead) { const auto job_id = JobID::FromInt(1); const auto actor_id = ActorID::Of(job_id, TaskID::Nil(), 0); auto request = Mocker::GenCreatePlacementGroupRequest( - /* name */ "", rpc::PlacementStrategy::SPREAD, + /* name */ "", + rpc::PlacementStrategy::SPREAD, /* bundles_count */ 2, /* cpu_num */ 1.0, /* job_id */ job_id, @@ -533,7 +538,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestAutomaticCleanupWhenOnlyJobDead) { // Test placement group is cleaned when only the job is dead. 
const auto job_id = JobID::FromInt(1); auto request = Mocker::GenCreatePlacementGroupRequest( - /* name */ "", rpc::PlacementStrategy::SPREAD, + /* name */ "", + rpc::PlacementStrategy::SPREAD, /* bundles_count */ 2, /* cpu_num */ 1.0, /* job_id */ job_id, @@ -563,7 +569,8 @@ TEST_F(GcsPlacementGroupManagerTest, const auto job_id = JobID::FromInt(1); const auto different_job_id = JobID::FromInt(3); auto request = Mocker::GenCreatePlacementGroupRequest( - /* name */ "", rpc::PlacementStrategy::SPREAD, + /* name */ "", + rpc::PlacementStrategy::SPREAD, /* bundles_count */ 2, /* cpu_num */ 1.0, /* job_id */ job_id, @@ -602,8 +609,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulingCanceledWhenPgIsInfeasible) { mock_placement_group_scheduler_->placement_groups_.clear(); // Mark it non-retryable. - gcs_placement_group_manager_->OnPlacementGroupCreationFailed(placement_group, - GetExpBackOff(), false); + gcs_placement_group_manager_->OnPlacementGroupCreationFailed( + placement_group, GetExpBackOff(), false); ASSERT_EQ(placement_group->GetStats().scheduling_state(), rpc::PlacementGroupStats::INFEASIBLE); diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index 374da4542..b68c42b7d 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -43,8 +43,8 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { std::make_unique(redis_client_)); gcs_resource_manager_ = std::make_shared( io_service_, nullptr, gcs_table_storage_); - ray_syncer_ = std::make_shared(io_service_, nullptr, - *gcs_resource_manager_); + ray_syncer_ = std::make_shared( + io_service_, nullptr, *gcs_resource_manager_); gcs_resource_scheduler_ = std::make_shared(*gcs_resource_manager_); store_client_ = std::make_shared(io_service_); @@ -53,8 +53,13 @@ gcs_node_manager_ = std::make_shared( gcs_publisher_, gcs_table_storage_, raylet_client_pool_); scheduler_ = std::make_shared( - io_service_, gcs_table_storage_, *gcs_node_manager_, *gcs_resource_manager_, - *gcs_resource_scheduler_, raylet_client_pool_, *ray_syncer_); + io_service_, + gcs_table_storage_, + *gcs_node_manager_, + *gcs_resource_manager_, + *gcs_resource_scheduler_, + raylet_client_pool_, + *ray_syncer_); } void TearDown() override { @@ -234,15 +239,15 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { // Failed to schedule the placement group, because the node resources are not enough. auto request = Mocker::GenCreatePlacementGroupRequest("", strategy); auto placement_group = std::make_shared(request, ""); - scheduler_->ScheduleUnplacedBundles(placement_group, failure_handler, - success_handler); + scheduler_->ScheduleUnplacedBundles( + placement_group, failure_handler, success_handler); WaitPlacementGroupPendingDone(1, GcsPlacementGroupStatus::FAILURE); CheckPlacementGroupSize(0, GcsPlacementGroupStatus::SUCCESS); // A new node is added, and the rescheduling is successful. 
AddNode(Mocker::GenNodeInfo(0), 2); - scheduler_->ScheduleUnplacedBundles(placement_group, failure_handler, - success_handler); + scheduler_->ScheduleUnplacedBundles( + placement_group, failure_handler, success_handler); ASSERT_TRUE(raylet_clients_[0]->GrantPrepareBundleResources()); WaitPendingDone(raylet_clients_[0]->commit_callbacks, 1); ASSERT_TRUE(raylet_clients_[0]->GrantCommitBundleResources()); @@ -414,8 +419,8 @@ TEST_F(GcsPlacementGroupSchedulerTest, TestStrictPackStrategyBalancedScheduling) auto request = Mocker::GenCreatePlacementGroupRequest("", rpc::PlacementStrategy::STRICT_PACK); auto placement_group = std::make_shared(request, ""); - scheduler_->ScheduleUnplacedBundles(placement_group, failure_handler, - success_handler); + scheduler_->ScheduleUnplacedBundles( + placement_group, failure_handler, success_handler); node_index = !raylet_clients_[0]->lease_callbacks.empty() ? 0 : 1; ++node_select_count[node_index]; diff --git a/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc index 4c666a784..21c5a0570 100644 --- a/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc @@ -69,10 +69,10 @@ TEST_F(GcsResourceManagerTest, TestResourceUsageAPI) { auto node_id = NodeID::FromBinary(node->node_id()); rpc::GetAllResourceUsageRequest get_all_request; rpc::GetAllResourceUsageReply get_all_reply; - auto send_reply_callback = [](ray::Status status, std::function f1, - std::function f2) {}; - gcs_resource_manager_->HandleGetAllResourceUsage(get_all_request, &get_all_reply, - send_reply_callback); + auto send_reply_callback = + [](ray::Status status, std::function f1, std::function f2) {}; + gcs_resource_manager_->HandleGetAllResourceUsage( + get_all_request, &get_all_reply, send_reply_callback); ASSERT_EQ(get_all_reply.resource_usage_data().batch().size(), 0); rpc::ReportResourceUsageRequest report_request; @@ -80,14 +80,14 @@ TEST_F(GcsResourceManagerTest, TestResourceUsageAPI) { (*report_request.mutable_resources()->mutable_resources_total())["CPU"] = 2; gcs_resource_manager_->UpdateNodeResourceUsage(node_id, report_request.resources()); - gcs_resource_manager_->HandleGetAllResourceUsage(get_all_request, &get_all_reply, - send_reply_callback); + gcs_resource_manager_->HandleGetAllResourceUsage( + get_all_request, &get_all_reply, send_reply_callback); ASSERT_EQ(get_all_reply.resource_usage_data().batch().size(), 1); gcs_resource_manager_->OnNodeDead(node_id); rpc::GetAllResourceUsageReply get_all_reply2; - gcs_resource_manager_->HandleGetAllResourceUsage(get_all_request, &get_all_reply2, - send_reply_callback); + gcs_resource_manager_->HandleGetAllResourceUsage( + get_all_request, &get_all_reply2, send_reply_callback); ASSERT_EQ(get_all_reply2.resource_usage_data().batch().size(), 0); } diff --git a/src/ray/gcs/gcs_server/test/gcs_resource_report_poller_test.cc b/src/ray/gcs/gcs_server/test/gcs_resource_report_poller_test.cc index 101e45df9..1d6a261c1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_resource_report_poller_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_resource_report_poller_test.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "ray/gcs/gcs_server/gcs_resource_report_poller.h" + #include #include "gtest/gtest.h" -#include "ray/gcs/gcs_server/gcs_resource_report_poller.h" #include "ray/gcs/test/gcs_test_util.h" namespace ray { @@ -28,13 +29,14 @@ class GcsResourceReportPollerTest : public ::testing::Test { GcsResourceReportPollerTest() : current_time_(0), gcs_resource_report_poller_( - nullptr, [](const rpc::ResourcesData &) {}, + nullptr, + [](const rpc::ResourcesData &) {}, [this]() { return current_time_; }, - [this](const rpc::Address &address, - std::shared_ptr &client_pool, - std::function - callback) { + [this]( + const rpc::Address &address, + std::shared_ptr &client_pool, + std::function callback) { if (request_report_) { request_report_(address, client_pool, callback); } @@ -56,7 +58,8 @@ class GcsResourceReportPollerTest : public ::testing::Test { int64_t current_time_; std::function &, + const rpc::Address &, + std::shared_ptr &, std::function)> request_report_; @@ -68,7 +71,8 @@ TEST_F(GcsResourceReportPollerTest, TestBasic) { bool rpc_sent = false; request_report_ = [&rpc_sent]( - const rpc::Address &, std::shared_ptr &, + const rpc::Address &, + std::shared_ptr &, std::function callback) { rpc_sent = true; @@ -99,7 +103,8 @@ TEST_F(GcsResourceReportPollerTest, TestFailedRpc) { bool rpc_sent = false; request_report_ = [&rpc_sent]( - const rpc::Address &, std::shared_ptr &, + const rpc::Address &, + std::shared_ptr &, std::function callback) { RAY_LOG(ERROR) << "Requesting"; @@ -139,7 +144,8 @@ TEST_F(GcsResourceReportPollerTest, TestMaxInFlight) { int num_rpcs_sent = 0; request_report_ = - [&](const rpc::Address &, std::shared_ptr &, + [&](const rpc::Address &, + std::shared_ptr &, std::function callback) { num_rpcs_sent++; @@ -174,7 +180,8 @@ TEST_F(GcsResourceReportPollerTest, TestNodeRemoval) { int num_rpcs_sent = 0; request_report_ = - [&](const rpc::Address &, std::shared_ptr &, + [&](const rpc::Address &, + std::shared_ptr &, std::function callback) { num_rpcs_sent++; @@ -220,7 +227,8 @@ TEST_F(GcsResourceReportPollerTest, TestPrioritizeNewNodes) { int num_rpcs_sent = 0; request_report_ = - [&](const rpc::Address &address, std::shared_ptr &, + [&](const rpc::Address &address, + std::shared_ptr &, std::function callback) { num_rpcs_sent++; diff --git a/src/ray/gcs/gcs_server/test/gcs_resource_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_resource_scheduler_test.cc index 304c4ddd4..8a8b852ad 100644 --- a/src/ray/gcs/gcs_server/test/gcs_resource_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_resource_scheduler_test.cc @@ -38,7 +38,8 @@ class GcsResourceSchedulerTest : public ::testing::Test { gcs_resource_manager_.reset(); } - void AddClusterResources(const NodeID &node_id, const std::string &resource_name, + void AddClusterResources(const NodeID &node_id, + const std::string &resource_name, double resource_value) { auto node = Mocker::GenNodeInfo(); node->set_node_id(node_id.Binary()); @@ -200,15 +201,17 @@ TEST_F(GcsResourceSchedulerTest, TestNodeFilter) { required_resources_list.emplace_back( ResourceMapToResourceRequest(resource_map, /*requires_object_store_memory=*/false)); const auto &result1 = gcs_resource_scheduler_->Schedule( - required_resources_list, gcs::SchedulingType::STRICT_SPREAD, - [](const NodeID &) { return false; }); + required_resources_list, gcs::SchedulingType::STRICT_SPREAD, [](const NodeID &) { + return false; + }); ASSERT_TRUE(result1.first == gcs::SchedulingResultStatus::INFEASIBLE); ASSERT_EQ(result1.second.size(), 0); // Scheduling succeeded. 
const auto &result2 = gcs_resource_scheduler_->Schedule( - required_resources_list, gcs::SchedulingType::STRICT_SPREAD, - [](const NodeID &) { return true; }); + required_resources_list, gcs::SchedulingType::STRICT_SPREAD, [](const NodeID &) { + return true; + }); ASSERT_TRUE(result2.first == gcs::SchedulingResultStatus::SUCCESS); ASSERT_EQ(result2.second.size(), 1); } diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 5308c0fc7..fe7863bae 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -75,11 +75,12 @@ class GcsServerTest : public ::testing::Test { bool MarkJobFinished(const rpc::MarkJobFinishedRequest &request) { std::promise promise; - client_->MarkJobFinished(request, [&promise](const Status &status, - const rpc::MarkJobFinishedReply &reply) { - RAY_CHECK_OK(status); - promise.set_value(true); - }); + client_->MarkJobFinished( + request, + [&promise](const Status &status, const rpc::MarkJobFinishedReply &reply) { + RAY_CHECK_OK(status); + promise.set_value(true); + }); return WaitReady(promise.get_future(), timeout_ms_); } @@ -88,17 +89,17 @@ class GcsServerTest : public ::testing::Test { request.set_actor_id(actor_id); boost::optional actor_table_data_opt; std::promise promise; - client_->GetActorInfo( - request, [&actor_table_data_opt, &promise](const Status &status, - const rpc::GetActorInfoReply &reply) { - RAY_CHECK_OK(status); - if (reply.has_actor_table_data()) { - actor_table_data_opt = reply.actor_table_data(); - } else { - actor_table_data_opt = boost::none; - } - promise.set_value(true); - }); + client_->GetActorInfo(request, + [&actor_table_data_opt, &promise]( + const Status &status, const rpc::GetActorInfoReply &reply) { + RAY_CHECK_OK(status); + if (reply.has_actor_table_data()) { + actor_table_data_opt = reply.actor_table_data(); + } else { + actor_table_data_opt = boost::none; + } + promise.set_value(true); + }); EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); return actor_table_data_opt; } @@ -129,8 +130,9 @@ class GcsServerTest : public ::testing::Test { rpc::GetAllNodeInfoRequest request; std::promise promise; client_->GetAllNodeInfo( - request, [&node_info_list, &promise](const Status &status, - const rpc::GetAllNodeInfoReply &reply) { + request, + [&node_info_list, &promise](const Status &status, + const rpc::GetAllNodeInfoReply &reply) { RAY_CHECK_OK(status); for (int index = 0; index < reply.node_info_list_size(); ++index) { node_info_list.push_back(reply.node_info_list(index)); @@ -143,11 +145,12 @@ class GcsServerTest : public ::testing::Test { bool ReportHeartbeat(const rpc::ReportHeartbeatRequest &request) { std::promise promise; - client_->ReportHeartbeat(request, [&promise](const Status &status, - const rpc::ReportHeartbeatReply &reply) { - RAY_CHECK_OK(status); - promise.set_value(true); - }); + client_->ReportHeartbeat( + request, + [&promise](const Status &status, const rpc::ReportHeartbeatReply &reply) { + RAY_CHECK_OK(status); + promise.set_value(true); + }); return WaitReady(promise.get_future(), timeout_ms_); } @@ -196,8 +199,9 @@ class GcsServerTest : public ::testing::Test { boost::optional worker_table_data_opt; std::promise promise; client_->GetWorkerInfo( - request, [&worker_table_data_opt, &promise]( - const Status &status, const rpc::GetWorkerInfoReply &reply) { + request, + [&worker_table_data_opt, &promise](const Status &status, + const rpc::GetWorkerInfoReply &reply) { 
RAY_CHECK_OK(status); if (reply.has_worker_table_data()) { worker_table_data_opt = reply.worker_table_data(); @@ -215,8 +219,9 @@ class GcsServerTest : public ::testing::Test { rpc::GetAllWorkerInfoRequest request; std::promise promise; client_->GetAllWorkerInfo( - request, [&worker_table_data, &promise](const Status &status, - const rpc::GetAllWorkerInfoReply &reply) { + request, + [&worker_table_data, &promise](const Status &status, + const rpc::GetAllWorkerInfoReply &reply) { RAY_CHECK_OK(status); for (int index = 0; index < reply.worker_table_data_size(); ++index) { worker_table_data.push_back(reply.worker_table_data(index)); diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index 4b0bd55d3..4f45437c7 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -61,8 +61,10 @@ struct GcsServerMocker { class MockRayletClient : public RayletClientInterface { public: /// WorkerLeaseInterface - ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, - bool disconnect_worker, bool worker_exiting) override { + ray::Status ReturnWorker(int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, + bool worker_exiting) override { if (disconnect_worker) { num_workers_disconnected++; } else { @@ -77,9 +79,11 @@ struct GcsServerMocker { /// WorkerLeaseInterface void RequestWorkerLease( - const rpc::TaskSpec &spec, bool grant_or_reject, + const rpc::TaskSpec &spec, + bool grant_or_reject, const rpc::ClientCallback &callback, - const int64_t backlog_size, const bool is_selected_based_on_locality) override { + const int64_t backlog_size, + const bool is_selected_based_on_locality) override { num_workers_requested += 1; callbacks.push_back(callback); } @@ -105,9 +109,13 @@ struct GcsServerMocker { } // Trigger reply to RequestWorkerLease. 
- bool GrantWorkerLease(const std::string &address, int port, const WorkerID &worker_id, - const NodeID &raylet_id, const NodeID &retry_at_raylet_id, - Status status = Status::OK(), bool rejected = false) { + bool GrantWorkerLease(const std::string &address, + int port, + const WorkerID &worker_id, + const NodeID &raylet_id, + const NodeID &retry_at_raylet_id, + Status status = Status::OK(), + bool rejected = false) { rpc::RequestWorkerLeaseReply reply; if (!retry_at_raylet_id.IsNil()) { reply.mutable_retry_at_raylet_address()->set_ip_address(address); @@ -243,7 +251,8 @@ struct GcsServerMocker { /// PinObjectsInterface void PinObjectIDs( - const rpc::Address &caller_address, const std::vector &object_ids, + const rpc::Address &caller_address, + const std::vector &object_ids, const ray::rpc::ClientCallback &callback) override {} /// DependencyWaiterInterface @@ -274,7 +283,8 @@ struct GcsServerMocker { /// ShutdownRaylet void ShutdownRaylet( - const NodeID &node_id, bool graceful, + const NodeID &node_id, + bool graceful, const rpc::ClientCallback &callback) override{}; ~MockRayletClient() {} @@ -365,7 +375,8 @@ struct GcsServerMocker { MockedGcsActorTable(std::shared_ptr store_client) : GcsActorTable(store_client) {} - Status Put(const ActorID &key, const rpc::ActorTableData &value, + Status Put(const ActorID &key, + const rpc::ActorTableData &value, const gcs::StatusCallback &callback) override { auto status = Status::OK(); callback(status); @@ -449,8 +460,10 @@ struct GcsServerMocker { MockGcsPubSub(std::shared_ptr redis_client) : GcsPubSub(redis_client) {} - Status Publish(std::string_view channel, const std::string &id, - const std::string &data, const gcs::StatusCallback &done) override { + Status Publish(std::string_view channel, + const std::string &id, + const std::string &data, + const gcs::StatusCallback &done) override { return Status::OK(); } }; diff --git a/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h b/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h index 8e1c472a8..79cf1ace1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h +++ b/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h @@ -127,7 +127,9 @@ class GcsTableStorageTestBase : public ::testing::Test { } template - int GetByJobId(TABLE &table, const JobID &job_id, const KEY &key, + int GetByJobId(TABLE &table, + const JobID &job_id, + const KEY &key, std::vector &values) { auto on_done = [this, &values](const absl::flat_hash_map &result) { values.clear(); diff --git a/src/ray/gcs/gcs_server/test/grpc_based_resource_broadcaster_test.cc b/src/ray/gcs/gcs_server/test/grpc_based_resource_broadcaster_test.cc index 0a931e4f3..a1e4be712 100644 --- a/src/ray/gcs/gcs_server/test/grpc_based_resource_broadcaster_test.cc +++ b/src/ray/gcs/gcs_server/test/grpc_based_resource_broadcaster_test.cc @@ -33,7 +33,8 @@ class GrpcBasedResourceBroadcasterTest : public ::testing::Test { /*raylet_client_pool*/ nullptr, /*send_batch*/ [this](const rpc::Address &address, - std::shared_ptr &pool, std::string &data, + std::shared_ptr &pool, + std::string &data, const rpc::ClientCallback &callback) { num_batches_sent_++; callbacks_.push_back(callback); diff --git a/src/ray/gcs/gcs_server/test/raylet_based_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/raylet_based_actor_scheduler_test.cc index eae8b14d2..083ded46f 100644 --- a/src/ray/gcs/gcs_server/test/raylet_based_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/raylet_based_actor_scheduler_test.cc @@ -37,7 +37,9 @@ class 
RayletBasedActorSchedulerTest : public ::testing::Test { std::make_shared(store_client_); gcs_actor_scheduler_ = std::make_shared( - io_service_, *gcs_actor_table_, *gcs_node_manager_, + io_service_, + *gcs_actor_table_, + *gcs_node_manager_, /*schedule_failure_handler=*/ [this](std::shared_ptr actor, const rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, @@ -109,8 +111,10 @@ TEST_F(RayletBasedActorSchedulerTest, TestScheduleActorSuccess) { // Grant a worker, then the actor creation request should be sent to the worker. WorkerID worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -143,9 +147,12 @@ TEST_F(RayletBasedActorSchedulerTest, TestScheduleRetryWhenLeasing) { ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_leasing_count_); // Mock an IOError reply, then the lease request will retry again. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil(), Status::IOError(""))); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil(), + Status::IOError(""))); ASSERT_EQ(1, gcs_actor_scheduler_->num_retry_leasing_count_); ASSERT_EQ(2, raylet_client_->num_workers_requested); ASSERT_EQ(1, raylet_client_->callbacks.size()); @@ -154,8 +161,10 @@ // Grant a worker, then the actor creation request should be sent to the worker. WorkerID worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -189,8 +198,10 @@ TEST_F(RayletBasedActorSchedulerTest, TestScheduleRetryWhenCreating) { // Grant a worker, then the actor creation request should be sent to the worker. WorkerID worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_creating_count_); @@ -237,9 +248,11 @@ TEST_F(RayletBasedActorSchedulerTest, TestNodeFailedWhenLeasing) { ASSERT_EQ(1, raylet_client_->callbacks.size()); // Grant a worker, which should have no effect. 
- ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil())); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil())); ASSERT_EQ(1, raylet_client_->num_workers_requested); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_leasing_count_); @@ -271,9 +284,11 @@ TEST_F(RayletBasedActorSchedulerTest, TestLeasingCancelledWhenLeasing) { ASSERT_EQ(1, raylet_client_->callbacks.size()); // Grant a worker, which should have no effect. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil())); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil())); ASSERT_EQ(1, raylet_client_->num_workers_requested); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(0, gcs_actor_scheduler_->num_retry_leasing_count_); @@ -300,9 +315,11 @@ TEST_F(RayletBasedActorSchedulerTest, TestNodeFailedWhenCreating) { ASSERT_EQ(0, worker_client_->callbacks.size()); // Grant a worker, then the actor creation request should be sent to the worker. - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node->node_manager_address(), node->node_manager_port(), WorkerID::FromRandom(), - node_id, NodeID::Nil())); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), + node->node_manager_port(), + WorkerID::FromRandom(), + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -344,8 +361,10 @@ TEST_F(RayletBasedActorSchedulerTest, TestWorkerFailedWhenCreating) { // Grant a worker, then the actor creation request should be sent to the worker. auto worker_id = WorkerID::FromRandom(); ASSERT_TRUE(raylet_client_->GrantWorkerLease(node->node_manager_address(), - node->node_manager_port(), worker_id, - node_id, NodeID::Nil())); + node->node_manager_port(), + worker_id, + node_id, + NodeID::Nil())); ASSERT_EQ(0, raylet_client_->callbacks.size()); ASSERT_EQ(1, worker_client_->callbacks.size()); @@ -388,9 +407,11 @@ TEST_F(RayletBasedActorSchedulerTest, TestSpillback) { // Grant with an invalid spillback node, and schedule again. auto invalid_node_id = NodeID::FromBinary(Mocker::GenNodeInfo()->node_id()); - ASSERT_TRUE(raylet_client_->GrantWorkerLease( - node2->node_manager_address(), node2->node_manager_port(), WorkerID::Nil(), - node_id_1, invalid_node_id)); + ASSERT_TRUE(raylet_client_->GrantWorkerLease(node2->node_manager_address(), + node2->node_manager_port(), + WorkerID::Nil(), + node_id_1, + invalid_node_id)); ASSERT_EQ(2, raylet_client_->num_workers_requested); ASSERT_EQ(1, raylet_client_->callbacks.size()); ASSERT_EQ(0, worker_client_->callbacks.size()); @@ -399,7 +420,9 @@ // node2. 
   ASSERT_TRUE(raylet_client_->GrantWorkerLease(node2->node_manager_address(),
                                                node2->node_manager_port(),
-                                               WorkerID::Nil(), node_id_1, node_id_2));
+                                               WorkerID::Nil(),
+                                               node_id_1,
+                                               node_id_2));
   ASSERT_EQ(3, raylet_client_->num_workers_requested);
   ASSERT_EQ(1, raylet_client_->callbacks.size());
   ASSERT_EQ(0, worker_client_->callbacks.size());
@@ -407,8 +430,10 @@ TEST_F(RayletBasedActorSchedulerTest, TestSpillback) {
   // Grant a worker, then the actor creation request should be sent to the worker.
   WorkerID worker_id = WorkerID::FromRandom();
   ASSERT_TRUE(raylet_client_->GrantWorkerLease(node2->node_manager_address(),
-                                               node2->node_manager_port(), worker_id,
-                                               node_id_2, NodeID::Nil()));
+                                               node2->node_manager_port(),
+                                               worker_id,
+                                               node_id_2,
+                                               NodeID::Nil()));
   ASSERT_EQ(0, raylet_client_->callbacks.size());
   ASSERT_EQ(1, worker_client_->callbacks.size());
@@ -459,8 +484,10 @@ TEST_F(RayletBasedActorSchedulerTest, TestReschedule) {
   // Grant a worker, then the actor creation request should be sent to the worker.
   ASSERT_TRUE(raylet_client_->GrantWorkerLease(node1->node_manager_address(),
-                                               node1->node_manager_port(), worker_id,
-                                               node_id_1, NodeID::Nil()));
+                                               node1->node_manager_port(),
+                                               worker_id,
+                                               node_id_1,
+                                               NodeID::Nil()));
   ASSERT_EQ(0, raylet_client_->callbacks.size());
   ASSERT_EQ(1, worker_client_->callbacks.size());
diff --git a/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc b/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc
index 0afe9ce32..dc60fa63e 100644
--- a/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc
+++ b/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc
@@ -27,7 +27,9 @@ class RedisGcsTableStorageTest : public gcs::GcsTableStorageTestBase {
   static void TearDownTestCase() { TestSetupUtil::ShutDownRedisServers(); }
   void SetUp() override {
-    gcs::RedisClientOptions options("127.0.0.1", TEST_REDIS_SERVER_PORTS.front(), "",
+    gcs::RedisClientOptions options("127.0.0.1",
+                                    TEST_REDIS_SERVER_PORTS.front(),
+                                    "",
                                     /*enable_sharding_conn=*/false);
     redis_client_ = std::make_shared(options);
     RAY_CHECK_OK(redis_client_->Connect(io_service_pool_->GetAll()));
diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h
index d1ccbee7f..01807f935 100644
--- a/src/ray/gcs/pb_util.h
+++ b/src/ray/gcs/pb_util.h
@@ -36,8 +36,11 @@ using ContextCase = rpc::ActorDeathCause::ContextCase;
 /// \param driver_pid Process ID of the driver running this job.
 /// \return The job table data created by this method.
 inline std::shared_ptr CreateJobTableData(
-    const ray::JobID &job_id, bool is_dead, const std::string &driver_ip_address,
-    int64_t driver_pid, const ray::rpc::JobConfig &job_config = {}) {
+    const ray::JobID &job_id,
+    bool is_dead,
+    const std::string &driver_ip_address,
+    int64_t driver_pid,
+    const ray::rpc::JobConfig &job_config = {}) {
   auto job_info_ptr = std::make_shared();
   job_info_ptr->set_job_id(job_id.Binary());
   job_info_ptr->set_is_dead(is_dead);
@@ -49,7 +52,9 @@ inline std::shared_ptr CreateJobTableData(
 /// Helper function to produce error table data.
 inline std::shared_ptr CreateErrorTableData(
-    const std::string &error_type, const std::string &error_msg, double timestamp,
+    const std::string &error_type,
+    const std::string &error_msg,
+    double timestamp,
     const JobID &job_id = JobID::Nil()) {
   uint32_t max_error_msg_size_bytes = RayConfig::instance().max_error_msg_size_bytes();
   auto error_info_ptr = std::make_shared();
@@ -70,8 +75,10 @@ inline std::shared_ptr CreateErrorTableData(
 /// Helper function to produce actor table data.
 inline std::shared_ptr CreateActorTableData(
-    const TaskSpecification &task_spec, const ray::rpc::Address &address,
-    ray::rpc::ActorTableData::ActorState state, uint64_t num_restarts) {
+    const TaskSpecification &task_spec,
+    const ray::rpc::Address &address,
+    ray::rpc::ActorTableData::ActorState state,
+    uint64_t num_restarts) {
   RAY_CHECK(task_spec.IsActorCreationTask());
   auto actor_id = task_spec.ActorCreationId();
   auto actor_info_ptr = std::make_shared();
@@ -95,8 +102,12 @@ inline std::shared_ptr CreateActorTableData(
 /// Helper function to produce worker failure data.
 inline std::shared_ptr CreateWorkerFailureData(
-    const NodeID &raylet_id, const WorkerID &worker_id, const std::string &address,
-    int32_t port, int64_t timestamp, rpc::WorkerExitType disconnect_type,
+    const NodeID &raylet_id,
+    const WorkerID &worker_id,
+    const std::string &address,
+    int32_t port,
+    int64_t timestamp,
+    rpc::WorkerExitType disconnect_type,
     const rpc::RayException *creation_task_exception = nullptr) {
   auto worker_failure_info_ptr = std::make_shared();
   worker_failure_info_ptr->mutable_worker_address()->set_raylet_id(raylet_id.Binary());
diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.cc b/src/ray/gcs/pubsub/gcs_pub_sub.cc
index 89ce22946..342f03ba5 100644
--- a/src/ray/gcs/pubsub/gcs_pub_sub.cc
+++ b/src/ray/gcs/pubsub/gcs_pub_sub.cc
@@ -19,8 +19,10 @@ namespace ray {
 namespace gcs {
-Status GcsPubSub::Publish(std::string_view channel, const std::string &id,
-                          const std::string &data, const StatusCallback &done) {
+Status GcsPubSub::Publish(std::string_view channel,
+                          const std::string &id,
+                          const std::string &data,
+                          const StatusCallback &done) {
   rpc::PubSubMessage message;
   message.set_id(id);
   message.set_data(data);
@@ -35,12 +37,15 @@ Status GcsPubSub::Publish(std::string_view channel, const std::string &id,
       GenChannelPattern(channel, id), message.SerializeAsString(), on_done);
 }
-Status GcsPubSub::Subscribe(std::string_view channel, const std::string &id,
-                            const Callback &subscribe, const StatusCallback &done) {
+Status GcsPubSub::Subscribe(std::string_view channel,
+                            const std::string &id,
+                            const Callback &subscribe,
+                            const StatusCallback &done) {
   return SubscribeInternal(channel, subscribe, done, id);
 }
-Status GcsPubSub::SubscribeAll(std::string_view channel, const Callback &subscribe,
+Status GcsPubSub::SubscribeAll(std::string_view channel,
+                               const Callback &subscribe,
                                const StatusCallback &done) {
   return SubscribeInternal(channel, subscribe, done, std::nullopt);
 }
@@ -60,7 +65,8 @@ Status GcsPubSub::Unsubscribe(std::string_view channel_name, const std::string &
 }
 Status GcsPubSub::SubscribeInternal(std::string_view channel_name,
-                                    const Callback &subscribe, const StatusCallback &done,
+                                    const Callback &subscribe,
+                                    const StatusCallback &done,
                                     const std::optional &id) {
   std::string pattern = GenChannelPattern(channel_name, id);
@@ -93,7 +99,10 @@ Status GcsPubSub::ExecuteCommandIfPossible(const std::string &channel_key,
       ray::gcs::RedisCallbackManager::instance().AllocateCallbackIndex();
   const auto &command_done_callback = command.done_callback;
   const auto &command_subscribe_callback = command.subscribe_callback;
-  auto callback = [this, channel_key, command_done_callback, command_subscribe_callback,
+  auto callback = [this,
+                   channel_key,
+                   command_done_callback,
+                   command_subscribe_callback,
                    callback_index](std::shared_ptr reply) {
     if (reply->IsNil()) {
       return;
@@ -146,11 +155,11 @@ Status GcsPubSub::ExecuteCommandIfPossible(const std::string &channel_key,
   };
   if (command.is_sub_or_unsub_all) {
-    status = redis_client_->GetPrimaryContext()->PSubscribeAsync(channel_key, callback,
-                                                                 callback_index);
+    status = redis_client_->GetPrimaryContext()->PSubscribeAsync(
+        channel_key, callback, callback_index);
   } else {
-    status = redis_client_->GetPrimaryContext()->SubscribeAsync(channel_key, callback,
-                                                                callback_index);
+    status = redis_client_->GetPrimaryContext()->SubscribeAsync(
+        channel_key, callback, callback_index);
   }
   channel.pending_reply = true;
   channel.command_queue.pop_front();
@@ -205,7 +214,8 @@ std::string GcsPubSub::DebugString() const {
   return stream.str();
 }
-Status GcsPublisher::PublishActor(const ActorID &id, const rpc::ActorTableData &message,
+Status GcsPublisher::PublishActor(const ActorID &id,
+                                  const rpc::ActorTableData &message,
                                   const StatusCallback &done) {
   if (publisher_ != nullptr) {
     rpc::PubMessage msg;
@@ -221,7 +231,8 @@ Status GcsPublisher::PublishActor(const ActorID &id, const rpc::ActorTableData &
   return pubsub_->Publish(ACTOR_CHANNEL, id.Hex(), message.SerializeAsString(), done);
 }
-Status GcsPublisher::PublishJob(const JobID &id, const rpc::JobTableData &message,
+Status GcsPublisher::PublishJob(const JobID &id,
+                                const rpc::JobTableData &message,
                                 const StatusCallback &done) {
   if (publisher_ != nullptr) {
     rpc::PubMessage msg;
@@ -237,7 +248,8 @@ Status GcsPublisher::PublishJob(const JobID &id, const rpc::JobTableData &messag
   return pubsub_->Publish(JOB_CHANNEL, id.Hex(), message.SerializeAsString(), done);
 }
-Status GcsPublisher::PublishNodeInfo(const NodeID &id, const rpc::GcsNodeInfo &message,
+Status GcsPublisher::PublishNodeInfo(const NodeID &id,
+                                     const rpc::GcsNodeInfo &message,
                                      const StatusCallback &done) {
   if (publisher_ != nullptr) {
     rpc::PubMessage msg;
@@ -267,8 +279,8 @@ Status GcsPublisher::PublishNodeResource(const NodeID &id,
     }
     return Status::OK();
   }
-  return pubsub_->Publish(NODE_RESOURCE_CHANNEL, id.Hex(), message.SerializeAsString(),
-                          done);
+  return pubsub_->Publish(
+      NODE_RESOURCE_CHANNEL, id.Hex(), message.SerializeAsString(), done);
 }
 Status GcsPublisher::PublishResourceBatch(const rpc::ResourceUsageBatchData &message,
@@ -333,7 +345,8 @@ Status GcsSubscriber::SubscribeAllJobs(
     RAY_LOG(WARNING) << "Subscription to Job channel failed: " << status.ToString();
   };
   if (!subscriber_->SubscribeChannel(
-          std::make_unique(), rpc::ChannelType::GCS_JOB_CHANNEL,
+          std::make_unique(),
+          rpc::ChannelType::GCS_JOB_CHANNEL,
           gcs_address_,
           [done](Status status) {
            if (done != nullptr) {
@@ -359,7 +372,8 @@ Status GcsSubscriber::SubscribeAllJobs(
 }
 Status GcsSubscriber::SubscribeActor(
-    const ActorID &id, const SubscribeCallback &subscribe,
+    const ActorID &id,
+    const SubscribeCallback &subscribe,
     const StatusCallback &done) {
   RAY_CHECK(subscribe != nullptr);
   if (subscriber_ != nullptr) {
@@ -376,14 +390,17 @@ Status GcsSubscriber::SubscribeActor(
                        << " failed: " << status.ToString();
     };
     if (!subscriber_->Subscribe(
-            std::make_unique(), rpc::ChannelType::GCS_ACTOR_CHANNEL,
-            gcs_address_, id.Binary(),
+            std::make_unique(),
+            rpc::ChannelType::GCS_ACTOR_CHANNEL,
+            gcs_address_,
+            id.Binary(),
             [done](Status status) {
               if (done != nullptr) {
                 done(status);
               }
             },
-            std::move(subscription_callback), std::move(subscription_failure_callback))) {
+            std::move(subscription_callback),
+            std::move(subscription_failure_callback))) {
      return Status::ObjectExists(
          "Actor already subscribed. Please unsubscribe first if it needs to be "
          "resubscribed.");
@@ -402,8 +419,8 @@ Status GcsSubscriber::SubscribeActor(
 Status GcsSubscriber::UnsubscribeActor(const ActorID &id) {
   if (subscriber_ != nullptr) {
-    subscriber_->Unsubscribe(rpc::ChannelType::GCS_ACTOR_CHANNEL, gcs_address_,
-                             id.Binary());
+    subscriber_->Unsubscribe(
+        rpc::ChannelType::GCS_ACTOR_CHANNEL, gcs_address_, id.Binary());
     return Status::OK();
   }
   return pubsub_->Unsubscribe(ACTOR_CHANNEL, id.Hex());
@@ -411,8 +428,8 @@ Status GcsSubscriber::UnsubscribeActor(const ActorID &id) {
 bool GcsSubscriber::IsActorUnsubscribed(const ActorID &id) {
   if (subscriber_ != nullptr) {
-    return !subscriber_->IsSubscribed(rpc::ChannelType::GCS_ACTOR_CHANNEL, gcs_address_,
-                                      id.Binary());
+    return !subscriber_->IsSubscribed(
+        rpc::ChannelType::GCS_ACTOR_CHANNEL, gcs_address_, id.Binary());
   }
   return pubsub_->IsUnsubscribed(ACTOR_CHANNEL, id.Hex());
 }
@@ -431,7 +448,8 @@ Status GcsSubscriber::SubscribeAllNodeInfo(
                      << status.ToString();
   };
   if (!subscriber_->SubscribeChannel(
-          std::make_unique(), rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
+          std::make_unique(),
+          rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
           gcs_address_,
          [done](Status status) {
            if (done != nullptr) {
@@ -471,7 +489,8 @@ Status GcsSubscriber::SubscribeAllNodeResources(
   };
   if (!subscriber_->SubscribeChannel(
           std::make_unique(),
-          rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL, gcs_address_,
+          rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
+          gcs_address_,
          [done](Status status) {
            if (done != nullptr) {
@@ -526,7 +545,8 @@ Status GcsSubscriber::SubscribeAllWorkerFailures(
   };
   if (!subscriber_->SubscribeChannel(
           std::make_unique(),
-          rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL, gcs_address_,
+          rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
+          gcs_address_,
          /*subscribe_done_callback=*/
          [done](Status status) {
            if (done != nullptr) {
diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h
index 9b261134d..15ff684f0 100644
--- a/src/ray/gcs/pubsub/gcs_pub_sub.h
+++ b/src/ray/gcs/pubsub/gcs_pub_sub.h
@@ -64,8 +64,10 @@ class GcsPubSub {
   /// \param data The data of message to be published to redis.
   /// \param done Callback that will be called when the message is published to redis.
   /// \return Status
-  virtual Status Publish(std::string_view channel, const std::string &id,
-                         const std::string &data, const StatusCallback &done);
+  virtual Status Publish(std::string_view channel,
+                         const std::string &id,
+                         const std::string &data,
+                         const StatusCallback &done);
   /// Subscribe to messages with the specified ID under the specified channel.
   ///
@@ -75,8 +77,10 @@ class GcsPubSub {
   /// received.
   /// \param done Callback that will be called when subscription is complete.
   /// \return Status
-  Status Subscribe(std::string_view channel, const std::string &id,
-                   const Callback &subscribe, const StatusCallback &done);
+  Status Subscribe(std::string_view channel,
+                   const std::string &id,
+                   const Callback &subscribe,
+                   const StatusCallback &done);
   /// Subscribe to messages with the specified channel.
   ///
@@ -85,7 +89,8 @@ class GcsPubSub {
   /// received.
   /// \param done Callback that will be called when subscription is complete.
   /// \return Status
-  Status SubscribeAll(std::string_view channel, const Callback &subscribe,
+  Status SubscribeAll(std::string_view channel,
+                      const Callback &subscribe,
                       const StatusCallback &done);
   /// Unsubscribe from messages with the specified ID under the specified channel.
@@ -112,7 +117,8 @@
   /// channel.
   struct Command {
     /// SUBSCRIBE constructor.
-    Command(const Callback &subscribe_callback, const StatusCallback &done_callback,
+    Command(const Callback &subscribe_callback,
+            const StatusCallback &done_callback,
             bool is_sub_or_unsub_all)
         : is_subscribe(true),
           subscribe_callback(subscribe_callback),
@@ -163,7 +169,8 @@ class GcsPubSub {
                                    GcsPubSub::Channel &channel)
       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-  Status SubscribeInternal(std::string_view channel_name, const Callback &subscribe,
+  Status SubscribeInternal(std::string_view channel_name,
+                           const Callback &subscribe,
                            const StatusCallback &done,
                            const std::optional &id);
@@ -213,28 +220,34 @@ class GcsPublisher {
   /// TODO: Implement optimization for channels where only latest data per ID is useful.
   /// Uses Redis pubsub.
-  Status PublishActor(const ActorID &id, const rpc::ActorTableData &message,
+  Status PublishActor(const ActorID &id,
+                      const rpc::ActorTableData &message,
                       const StatusCallback &done);
   /// Uses Redis pubsub.
-  Status PublishJob(const JobID &id, const rpc::JobTableData &message,
+  Status PublishJob(const JobID &id,
+                    const rpc::JobTableData &message,
                     const StatusCallback &done);
   /// Uses Redis pubsub.
-  Status PublishNodeInfo(const NodeID &id, const rpc::GcsNodeInfo &message,
+  Status PublishNodeInfo(const NodeID &id,
+                         const rpc::GcsNodeInfo &message,
                          const StatusCallback &done);
   /// Uses Redis pubsub.
-  Status PublishNodeResource(const NodeID &id, const rpc::NodeResourceChange &message,
+  Status PublishNodeResource(const NodeID &id,
+                             const rpc::NodeResourceChange &message,
                              const StatusCallback &done);
   /// Actually rpc::WorkerDeltaData is not a delta message.
   /// Uses Redis pubsub.
-  Status PublishWorkerFailure(const WorkerID &id, const rpc::WorkerDeltaData &message,
+  Status PublishWorkerFailure(const WorkerID &id,
+                              const rpc::WorkerDeltaData &message,
                               const StatusCallback &done);
   /// Uses Redis pubsub.
-  Status PublishError(const std::string &id, const rpc::ErrorTableData &message,
+  Status PublishError(const std::string &id,
+                      const rpc::ErrorTableData &message,
                       const StatusCallback &done);
   /// TODO: remove once it is converted to GRPC-based push broadcasting.
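(Illustrative aside, not part of the patch: under the new style, a call site that overflows the column limit also gets one argument per line. The names `publisher`, `actor_id`, and `actor_data` below are hypothetical; this only sketches how such a call reads after reformatting.)

    // Hypothetical caller, assuming a gcs::GcsPublisher `publisher`, an
    // ActorID `actor_id`, and a populated rpc::ActorTableData `actor_data`.
    RAY_CHECK_OK(publisher.PublishActor(
        actor_id,
        actor_data,
        /*done=*/[](const ray::Status &status) { RAY_CHECK_OK(status); }));

Because every argument owns its line, later adding or removing a single argument changes exactly one line of the call.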
diff --git a/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc b/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc
index 239c86353..dd8b4f3ba 100644
--- a/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc
+++ b/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc
@@ -56,7 +56,8 @@ class GcsPubSubTest : public ::testing::Test {
     client_.reset();
   }
-  void Subscribe(const std::string &channel, const std::string &id,
+  void Subscribe(const std::string &channel,
+                 const std::string &id,
                  std::vector &result) {
     std::promise promise;
     auto done = [&promise](const Status &status) { promise.set_value(status.ok()); };
@@ -84,7 +85,8 @@ class GcsPubSubTest : public ::testing::Test {
     RAY_CHECK_OK(pub_sub_->Unsubscribe(channel, id));
   }
-  bool Publish(const std::string &channel, const std::string &id,
+  bool Publish(const std::string &channel,
+               const std::string &id,
               const std::string &data) {
     std::promise promise;
     auto done = [&promise](const Status &status) { promise.set_value(status.ok()); };
diff --git a/src/ray/gcs/redis_async_context.cc b/src/ray/gcs/redis_async_context.cc
index 42662cf5b..15e20a24c 100644
--- a/src/ray/gcs/redis_async_context.cc
+++ b/src/ray/gcs/redis_async_context.cc
@@ -65,8 +65,10 @@ void RedisAsyncContext::RedisAsyncHandleWrite() {
   redisAsyncHandleWrite(redis_async_context_);
 }
-Status RedisAsyncContext::RedisAsyncCommand(redisCallbackFn *fn, void *privdata,
-                                            const char *format, ...) {
+Status RedisAsyncContext::RedisAsyncCommand(redisCallbackFn *fn,
+                                            void *privdata,
+                                            const char *format,
+                                            ...) {
   va_list ap;
   va_start(ap, format);
@@ -89,8 +91,10 @@ Status RedisAsyncContext::RedisAsyncCommand(redisCallbackFn *fn, void *privdata,
   return Status::OK();
 }
-Status RedisAsyncContext::RedisAsyncCommandArgv(redisCallbackFn *fn, void *privdata,
-                                                int argc, const char **argv,
+Status RedisAsyncContext::RedisAsyncCommandArgv(redisCallbackFn *fn,
+                                                void *privdata,
+                                                int argc,
+                                                const char **argv,
                                                 const size_t *argvlen) {
   int ret_code = 0;
   {
diff --git a/src/ray/gcs/redis_async_context.h b/src/ray/gcs/redis_async_context.h
index 5d0a95b74..4b5880c78 100644
--- a/src/ray/gcs/redis_async_context.h
+++ b/src/ray/gcs/redis_async_context.h
@@ -71,8 +71,11 @@ class RedisAsyncContext {
   /// \param argv Array with arguments.
   /// \param argvlen Array with each argument's length.
   /// \return Status
-  Status RedisAsyncCommandArgv(redisCallbackFn *fn, void *privdata, int argc,
-                               const char **argv, const size_t *argvlen);
+  Status RedisAsyncCommandArgv(redisCallbackFn *fn,
+                               void *privdata,
+                               int argc,
+                               const char **argv,
+                               const size_t *argvlen);
  private:
   /// This mutex is used to protect `redis_async_context`.
diff --git a/src/ray/gcs/redis_client.cc b/src/ray/gcs/redis_client.cc
index 0cefc3812..9b0bed838 100644
--- a/src/ray/gcs/redis_client.cc
+++ b/src/ray/gcs/redis_client.cc
@@ -28,7 +28,9 @@ namespace gcs {
 /// Run redis command using specified context and store the result in `reply`. Return true
 /// if the number of attempts didn't reach `redis_db_connect_retries`.
 static bool RunRedisCommandWithRetries(
-    redisContext *context, const char *command, redisReply **reply,
+    redisContext *context,
+    const char *command,
+    redisReply **reply,
     const std::function &condition) {
   int num_attempts = 0;
   while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
@@ -62,7 +64,8 @@ static int DoGetNextJobID(redisContext *context) {
   return counter;
 }
-static void GetRedisShards(redisContext *context, std::vector *addresses,
+static void GetRedisShards(redisContext *context,
+                           std::vector *addresses,
                            std::vector *ports) {
   // Get the total number of Redis shards in the system.
   redisReply *reply = nullptr;
@@ -80,7 +83,9 @@
   // Get the addresses of all of the Redis shards.
   under_retry_limit = RunRedisCommandWithRetries(
-      context, "LRANGE RedisShards 0 -1", &reply,
+      context,
+      "LRANGE RedisShards 0 -1",
+      &reply,
       [&num_redis_shards](const redisReply *reply) {
         return static_cast(reply->elements) == num_redis_shards;
       });
@@ -123,11 +128,13 @@ Status RedisClient::Connect(std::vector io_services)
   primary_context_ = std::make_shared(*io_services[0]);
-  RAY_CHECK_OK(primary_context_->Connect(
-      options_.server_ip_, options_.server_port_,
-      /*sharding=*/true,
-      /*password=*/options_.password_, options_.enable_sync_conn_,
-      options_.enable_async_conn_, options_.enable_subscribe_conn_));
+  RAY_CHECK_OK(primary_context_->Connect(options_.server_ip_,
+                                         options_.server_port_,
+                                         /*sharding=*/true,
+                                         /*password=*/options_.password_,
+                                         options_.enable_sync_conn_,
+                                         options_.enable_async_conn_,
+                                         options_.enable_subscribe_conn_));
   if (options_.enable_sharding_conn_) {
     // Moving sharding into the constructor means that sharding = true by default.
@@ -147,19 +154,24 @@
       // Populate shard_contexts.
       shard_contexts_.push_back(std::make_shared(io_service));
       // Only async context is used in sharding context, so we disable the other two.
-      RAY_CHECK_OK(shard_contexts_[i]->Connect(
-          addresses[i], ports[i], /*sharding=*/true,
-          /*password=*/options_.password_, /*enable_sync_conn=*/false,
-          /*enable_async_conn=*/true, /*enable_subscribe_conn=*/false));
+      RAY_CHECK_OK(shard_contexts_[i]->Connect(addresses[i],
+                                               ports[i],
+                                               /*sharding=*/true,
+                                               /*password=*/options_.password_,
+                                               /*enable_sync_conn=*/false,
+                                               /*enable_async_conn=*/true,
+                                               /*enable_subscribe_conn=*/false));
     }
   } else {
     shard_contexts_.push_back(std::make_shared(*io_services[0]));
     // Only async context is used in sharding context, so we disable the other two.
-    RAY_CHECK_OK(shard_contexts_[0]->Connect(
-        options_.server_ip_, options_.server_port_,
-        /*sharding=*/true,
-        /*password=*/options_.password_, /*enable_sync_conn=*/false,
-        /*enable_async_conn=*/true, /*enable_subscribe_conn=*/false));
+    RAY_CHECK_OK(shard_contexts_[0]->Connect(options_.server_ip_,
+                                             options_.server_port_,
+                                             /*sharding=*/true,
+                                             /*password=*/options_.password_,
+                                             /*enable_sync_conn=*/false,
+                                             /*enable_async_conn=*/true,
+                                             /*enable_subscribe_conn=*/false));
   }
   Attach();
diff --git a/src/ray/gcs/redis_client.h b/src/ray/gcs/redis_client.h
index 4d061d121..6c506bc10 100644
--- a/src/ray/gcs/redis_client.h
+++ b/src/ray/gcs/redis_client.h
@@ -30,9 +30,13 @@ class RedisContext;
 class RedisClientOptions {
  public:
-  RedisClientOptions(const std::string &ip, int port, const std::string &password,
-                     bool enable_sharding_conn = true, bool enable_sync_conn = true,
-                     bool enable_async_conn = true, bool enable_subscribe_conn = true)
+  RedisClientOptions(const std::string &ip,
+                     int port,
+                     const std::string &password,
+                     bool enable_sharding_conn = true,
+                     bool enable_sync_conn = true,
+                     bool enable_async_conn = true,
+                     bool enable_subscribe_conn = true)
       : server_ip_(ip),
         server_port_(port),
         password_(password),
diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc
index abfab69f6..e005336d8 100644
--- a/src/ray/gcs/redis_context.cc
+++ b/src/ray/gcs/redis_context.cc
@@ -304,9 +311,11 @@ void FreeRedisContext(redisAsyncContext *context) {}
 void FreeRedisContext(RedisAsyncContext *context) {}
 template
-Status ConnectWithoutRetries(const std::string &address, int port,
+Status ConnectWithoutRetries(const std::string &address,
+                             int port,
                              const RedisConnectFunction &connect_function,
-                             RedisContext **context, std::string &errorMessage) {
+                             RedisContext **context,
+                             std::string &errorMessage) {
   // This currently returns the errorMessage in two different ways,
   // as an output parameter and in the Status::RedisError,
   // because we're not sure whether we'll want to change what this returns.
@@ -331,7 +333,8 @@ Status ConnectWithoutRetries(const std::string &address, int port,
 }
 template
-Status ConnectWithRetries(const std::string &address, int port,
+Status ConnectWithRetries(const std::string &address,
+                          int port,
                           const RedisConnectFunction &connect_function,
                           RedisContext **context) {
   int connection_attempts = 0;
@@ -360,13 +363,17 @@ Status RedisContext::PingPort(const std::string &address, int port) {
   std::string errorMessage;
-  return ConnectWithoutRetries(address, port, redisConnect,
-                               static_cast(nullptr), errorMessage);
+  return ConnectWithoutRetries(
+      address, port, redisConnect, static_cast(nullptr), errorMessage);
 }
-Status RedisContext::Connect(const std::string &address, int port, bool sharding,
-                             const std::string &password, bool enable_sync_conn,
-                             bool enable_async_conn, bool enable_subscribe_conn) {
+Status RedisContext::Connect(const std::string &address,
+                             int port,
+                             bool sharding,
+                             const std::string &password,
+                             bool enable_sync_conn,
+                             bool enable_async_conn,
+                             bool enable_subscribe_conn) {
   RAY_CHECK(!context_);
   RAY_CHECK(!redis_async_context_);
   RAY_CHECK(!async_redis_subscribe_context_);
@@ -439,7 +446,10 @@ Status RedisContext::RunArgvAsync(const std::vector &args,
   // Run the Redis command.
   Status status = redis_async_context_->RedisAsyncCommandArgv(
       reinterpret_cast(&GlobalRedisCallback),
-      reinterpret_cast(callback_index), args.size(), argv.data(), argc.data());
+      reinterpret_cast(callback_index),
+      args.size(),
+      argv.data(),
+      argc.data());
   return status;
 }
@@ -461,14 +471,19 @@ Status RedisContext::SubscribeAsync(const NodeID &node_id,
     std::string redis_command = "SUBSCRIBE %d";
     status = async_redis_subscribe_context_->RedisAsyncCommand(
         reinterpret_cast(&GlobalRedisCallback),
-        reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel);
+        reinterpret_cast(callback_index),
+        redis_command.c_str(),
+        pubsub_channel);
   } else {
     // Subscribe only to messages sent to this client.
     std::string redis_command = "SUBSCRIBE %d:%b";
     status = async_redis_subscribe_context_->RedisAsyncCommand(
         reinterpret_cast(&GlobalRedisCallback),
-        reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel,
-        node_id.Data(), node_id.Size());
+        reinterpret_cast(callback_index),
+        redis_command.c_str(),
+        pubsub_channel,
+        node_id.Data(),
+        node_id.Size());
   }
   return status;
@@ -479,12 +494,14 @@ Status RedisContext::SubscribeAsync(const std::string &channel,
                                     int64_t callback_index) {
   RAY_CHECK(async_redis_subscribe_context_);
-  RAY_UNUSED(RedisCallbackManager::instance().AddCallback(redisCallback, true,
-                                                          io_service_, callback_index));
+  RAY_UNUSED(RedisCallbackManager::instance().AddCallback(
+      redisCallback, true, io_service_, callback_index));
   std::string redis_command = "SUBSCRIBE %b";
   return async_redis_subscribe_context_->RedisAsyncCommand(
       reinterpret_cast(&GlobalRedisCallback),
-      reinterpret_cast(callback_index), redis_command.c_str(), channel.c_str(),
+      reinterpret_cast(callback_index),
+      redis_command.c_str(),
+      channel.c_str(),
       channel.size());
 }
@@ -494,7 +511,9 @@ Status RedisContext::UnsubscribeAsync(const std::string &channel) {
   std::string redis_command = "UNSUBSCRIBE %b";
   return async_redis_subscribe_context_->RedisAsyncCommand(
       reinterpret_cast(&GlobalRedisCallback),
-      reinterpret_cast(-1), redis_command.c_str(), channel.c_str(),
+      reinterpret_cast(-1),
+      redis_command.c_str(),
+      channel.c_str(),
       channel.size());
 }
@@ -503,12 +522,14 @@ Status RedisContext::PSubscribeAsync(const std::string &pattern,
                                      int64_t callback_index) {
   RAY_CHECK(async_redis_subscribe_context_);
-  RAY_UNUSED(RedisCallbackManager::instance().AddCallback(redisCallback, true,
-                                                          io_service_, callback_index));
+  RAY_UNUSED(RedisCallbackManager::instance().AddCallback(
+      redisCallback, true, io_service_, callback_index));
   std::string redis_command = "PSUBSCRIBE %b";
   return async_redis_subscribe_context_->RedisAsyncCommand(
       reinterpret_cast(&GlobalRedisCallback),
-      reinterpret_cast(callback_index), redis_command.c_str(), pattern.c_str(),
+      reinterpret_cast(callback_index),
+      redis_command.c_str(),
+      pattern.c_str(),
       pattern.size());
 }
@@ -518,11 +539,14 @@ Status RedisContext::PUnsubscribeAsync(const std::string &pattern) {
   std::string redis_command = "PUNSUBSCRIBE %b";
   return async_redis_subscribe_context_->RedisAsyncCommand(
       reinterpret_cast(&GlobalRedisCallback),
-      reinterpret_cast(-1), redis_command.c_str(), pattern.c_str(),
+      reinterpret_cast(-1),
+      redis_command.c_str(),
+      pattern.c_str(),
       pattern.size());
 }
-Status RedisContext::PublishAsync(const std::string &channel, const std::string &message,
+Status RedisContext::PublishAsync(const std::string &channel,
+                                  const std::string &message,
                                   const RedisCallback &redisCallback) {
   std::vector args = {"PUBLISH", channel, message};
   return RunArgvAsync(args, redisCallback);
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index a41597ece..33e8bc483 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -120,7 +120,9 @@ class RedisCallbackManager {
   struct CallbackItem : public std::enable_shared_from_this {
     CallbackItem() = default;
-    CallbackItem(const RedisCallback &callback, bool is_subscription, int64_t start_time,
+    CallbackItem(const RedisCallback &callback,
+                 bool is_subscription,
+                 int64_t start_time,
                  instrumented_io_context &io_service)
         : callback_(callback),
           is_subscription_(is_subscription),
@@ -145,8 +147,10 @@
   int64_t AllocateCallbackIndex();
   /// Add a callback at an optionally specified index.
-  int64_t AddCallback(const RedisCallback &function, bool is_subscription,
-                      instrumented_io_context &io_service, int64_t callback_index = -1);
+  int64_t AddCallback(const RedisCallback &function,
+                      bool is_subscription,
+                      instrumented_io_context &io_service,
+                      int64_t callback_index = -1);
   /// Remove a callback.
   void RemoveCallback(int64_t callback_index);
@@ -179,9 +183,13 @@ class RedisContext {
   /// \return The Status that we would get if we Connected.
   Status PingPort(const std::string &address, int port);
-  Status Connect(const std::string &address, int port, bool sharding,
-                 const std::string &password, bool enable_sync_conn = true,
-                 bool enable_async_conn = true, bool enable_subscribe_conn = true);
+  Status Connect(const std::string &address,
+                 int port,
+                 bool sharding,
+                 const std::string &password,
+                 bool enable_sync_conn = true,
+                 bool enable_async_conn = true,
+                 bool enable_subscribe_conn = true);
   /// Run an arbitrary Redis command synchronously.
   ///
@@ -204,8 +212,10 @@ class RedisContext {
   /// \param redisCallback The callback function that the notification calls.
   /// \param out_callback_index The output pointer to callback index.
   /// \return Status.
-  Status SubscribeAsync(const NodeID &node_id, const TablePubsub pubsub_channel,
-                        const RedisCallback &redisCallback, int64_t *out_callback_index);
+  Status SubscribeAsync(const NodeID &node_id,
+                        const TablePubsub pubsub_channel,
+                        const RedisCallback &redisCallback,
+                        int64_t *out_callback_index);
   /// Subscribes the client to the given channel.
   ///
@@ -215,7 +225,8 @@ class RedisContext {
   /// must already be allocated in the callback manager via
   /// RedisCallbackManager::AllocateCallbackIndex.
   /// \return Status.
-  Status SubscribeAsync(const std::string &channel, const RedisCallback &redisCallback,
+  Status SubscribeAsync(const std::string &channel,
+                        const RedisCallback &redisCallback,
                         int64_t callback_index);
   /// Unsubscribes the client from the given channel.
@@ -232,7 +243,8 @@ class RedisContext {
   /// must already be allocated in the callback manager via
   /// RedisCallbackManager::AllocateCallbackIndex.
   /// \return Status.
-  Status PSubscribeAsync(const std::string &pattern, const RedisCallback &redisCallback,
+  Status PSubscribeAsync(const std::string &pattern,
+                         const RedisCallback &redisCallback,
                          int64_t callback_index);
   /// Unsubscribes the client from the given pattern.
@@ -247,7 +259,8 @@ class RedisContext {
   /// \param message The message to be published to redis.
   /// \param redisCallback The callback will be called when the message is published to
   /// redis. \return Status.
-  Status PublishAsync(const std::string &channel, const std::string &message,
+  Status PublishAsync(const std::string &channel,
+                      const std::string &message,
                       const RedisCallback &redisCallback);
   redisContext *sync_context() {
diff --git a/src/ray/gcs/store_client/in_memory_store_client.cc b/src/ray/gcs/store_client/in_memory_store_client.cc
index 04835b0fb..a99785bc2 100644
--- a/src/ray/gcs/store_client/in_memory_store_client.cc
+++ b/src/ray/gcs/store_client/in_memory_store_client.cc
@@ -19,7 +19,8 @@ namespace ray {
 namespace gcs {
 Status InMemoryStoreClient::AsyncPut(const std::string &table_name,
-                                     const std::string &key, const std::string &data,
+                                     const std::string &key,
+                                     const std::string &data,
                                      const StatusCallback &callback) {
   auto table = GetOrCreateTable(table_name);
   absl::MutexLock lock(&(table->mutex_));
@@ -136,8 +137,10 @@ Status InMemoryStoreClient::AsyncBatchDelete(const std::string &table_name,
 }
 Status InMemoryStoreClient::AsyncBatchDeleteWithIndex(
-    const std::string &table_name, const std::vector &keys,
-    const std::vector &index_keys, const StatusCallback &callback) {
+    const std::string &table_name,
+    const std::vector &keys,
+    const std::vector &index_keys,
+    const StatusCallback &callback) {
   RAY_CHECK(keys.size() == index_keys.size());
   auto table = GetOrCreateTable(table_name);
@@ -169,7 +172,8 @@
 }
 Status InMemoryStoreClient::AsyncGetByIndex(
-    const std::string &table_name, const std::string &index_key,
+    const std::string &table_name,
+    const std::string &index_key,
     const MapCallback &callback) {
   RAY_CHECK(callback);
   auto table = GetOrCreateTable(table_name);
diff --git a/src/ray/gcs/store_client/in_memory_store_client.h b/src/ray/gcs/store_client/in_memory_store_client.h
index d30797148..60af17a2b 100644
--- a/src/ray/gcs/store_client/in_memory_store_client.h
+++ b/src/ray/gcs/store_client/in_memory_store_client.h
@@ -32,26 +32,34 @@ class InMemoryStoreClient : public StoreClient {
   explicit InMemoryStoreClient(instrumented_io_context &main_io_service)
       : main_io_service_(main_io_service) {}
-  Status AsyncPut(const std::string &table_name, const std::string &key,
-                  const std::string &data, const StatusCallback &callback) override;
+  Status AsyncPut(const std::string &table_name,
+                  const std::string &key,
+                  const std::string &data,
+                  const StatusCallback &callback) override;
-  Status AsyncPutWithIndex(const std::string &table_name, const std::string &key,
-                           const std::string &index_key, const std::string &data,
+  Status AsyncPutWithIndex(const std::string &table_name,
+                           const std::string &key,
+                           const std::string &index_key,
+                           const std::string &data,
                            const StatusCallback &callback) override;
-  Status AsyncGet(const std::string &table_name, const std::string &key,
+  Status AsyncGet(const std::string &table_name,
+                  const std::string &key,
                   const OptionalItemCallback &callback) override;
-  Status AsyncGetByIndex(const std::string &table_name, const std::string &index_key,
+  Status AsyncGetByIndex(const std::string &table_name,
+                         const std::string &index_key,
                          const MapCallback &callback) override;
   Status AsyncGetAll(const std::string &table_name,
                      const MapCallback &callback) override;
-  Status AsyncDelete(const std::string &table_name, const std::string &key,
+  Status AsyncDelete(const std::string &table_name,
+                     const std::string &key,
                      const StatusCallback &callback) override;
-  Status AsyncDeleteWithIndex(const std::string &table_name, const std::string &key,
+  Status AsyncDeleteWithIndex(const std::string &table_name,
+                              const std::string &key,
                               const std::string &index_key,
                               const StatusCallback &callback) override;
@@ -64,7 +72,8 @@ class InMemoryStoreClient : public StoreClient {
                                    const std::vector &index_keys,
                                    const StatusCallback &callback) override;
-  Status AsyncDeleteByIndex(const std::string &table_name, const std::string &index_key,
+  Status AsyncDeleteByIndex(const std::string &table_name,
+                            const std::string &index_key,
                             const StatusCallback &callback) override;
   int GetNextJobID() override;
diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc
index bfa08f0fe..2ca655bea 100644
--- a/src/ray/gcs/store_client/redis_store_client.cc
+++ b/src/ray/gcs/store_client/redis_store_client.cc
@@ -27,7 +27,8 @@ namespace gcs {
 std::string RedisStoreClient::table_separator_ = ":";
 std::string RedisStoreClient::index_table_separator_ = "&";
-Status RedisStoreClient::AsyncPut(const std::string &table_name, const std::string &key,
+Status RedisStoreClient::AsyncPut(const std::string &table_name,
+                                  const std::string &key,
                                   const std::string &data,
                                   const StatusCallback &callback) {
   return DoPut(GenRedisKey(table_name, key), data, callback);
@@ -58,7 +59,8 @@
   return status;
 }
-Status RedisStoreClient::AsyncGet(const std::string &table_name, const std::string &key,
+Status RedisStoreClient::AsyncGet(const std::string &table_name,
+                                  const std::string &key,
                                   const OptionalItemCallback &callback) {
   RAY_CHECK(callback != nullptr);
@@ -135,8 +137,10 @@
 }
 Status RedisStoreClient::AsyncBatchDeleteWithIndex(
-    const std::string &table_name, const std::vector &keys,
-    const std::vector &index_keys, const StatusCallback &callback) {
+    const std::string &table_name,
+    const std::vector &keys,
+    const std::vector &index_keys,
+    const StatusCallback &callback) {
   RAY_CHECK(keys.size() == index_keys.size());
   std::vector redis_keys;
@@ -150,7 +154,8 @@
 }
 Status RedisStoreClient::AsyncGetByIndex(
-    const std::string &table_name, const std::string &index_key,
+    const std::string &table_name,
+    const std::string &index_key,
     const MapCallback &callback) {
   RAY_CHECK(callback);
   std::string match_pattern = GenRedisMatchPattern(table_name, index_key);
@@ -202,7 +207,8 @@ Status RedisStoreClient::AsyncDeleteByIndex(const std::string &table_name,
   return scanner->ScanKeys(match_pattern, on_done);
 }
-Status RedisStoreClient::DoPut(const std::string &key, const std::string &data,
+Status RedisStoreClient::DoPut(const std::string &key,
+                               const std::string &data,
                                const StatusCallback &callback) {
   std::vector args = {"SET", key, data};
   RedisCallback write_callback = nullptr;
@@ -229,8 +235,8 @@
   for (auto &command_list : del_commands_by_shards) {
     for (auto &command : command_list.second) {
-      auto delete_callback = [finished_count, total_count,
-                              callback](const std::shared_ptr &reply) {
+      auto delete_callback = [finished_count, total_count, callback](
+                                 const std::shared_ptr &reply) {
         ++(*finished_count);
         if (*finished_count == total_count) {
           if (callback) {
@@ -247,7 +253,8 @@
 absl::flat_hash_map>>
 RedisStoreClient::GenCommandsByShards(const std::shared_ptr &redis_client,
                                       const std::string &command,
-                                      const std::vector &keys, int *count) {
+                                      const std::vector &keys,
+                                      int *count) {
   absl::flat_hash_map>>
       commands_by_shards;
   for (auto &key : keys) {
@@ -317,7 +324,8 @@ std::string RedisStoreClient::GetKeyFromRedisKey(const std::string &redis_key,
 }
 Status RedisStoreClient::MGetValues(
-    std::shared_ptr redis_client, const std::string &table_name,
+    std::shared_ptr redis_client,
+    const std::string &table_name,
     const std::vector &keys,
     const MapCallback &callback) {
   // The `MGET` command for each shard.
@@ -329,24 +337,26 @@ Status RedisStoreClient::MGetValues(
   for (auto &command_list : mget_commands_by_shards) {
     for (auto &command : command_list.second) {
       auto mget_keys = std::move(command);
-      auto mget_callback = [table_name, finished_count, total_count, mget_keys, callback,
-                            key_value_map](const std::shared_ptr &reply) {
-        if (!reply->IsNil()) {
-          auto value = reply->ReadAsStringArray();
-          // The 0th element of mget_keys is "MGET", so we start from the 1st element.
-          for (size_t index = 0; index < value.size(); ++index) {
-            if (value[index].has_value()) {
-              (*key_value_map)[GetKeyFromRedisKey(mget_keys[index + 1], table_name)] =
-                  *(value[index]);
+      auto mget_callback =
+          [table_name, finished_count, total_count, mget_keys, callback, key_value_map](
+              const std::shared_ptr &reply) {
+            if (!reply->IsNil()) {
+              auto value = reply->ReadAsStringArray();
+              // The 0th element of mget_keys is "MGET", so we start from the 1st
+              // element.
+              for (size_t index = 0; index < value.size(); ++index) {
+                if (value[index].has_value()) {
+                  (*key_value_map)[GetKeyFromRedisKey(mget_keys[index + 1], table_name)] =
+                      *(value[index]);
+                }
+              }
             }
-          }
-        }
-        ++(*finished_count);
-        if (*finished_count == total_count) {
-          callback(std::move(*key_value_map));
-        }
-      };
+            ++(*finished_count);
+            if (*finished_count == total_count) {
+              callback(std::move(*key_value_map));
+            }
+          };
       RAY_CHECK_OK(command_list.first->RunArgvAsync(mget_keys, mget_callback));
     }
   }
@@ -404,14 +414,17 @@ void RedisStoreClient::RedisScanner::Scan(const std::string &match_pattern,
     size_t shard_index = item.first;
     size_t cursor = item.second;
-    auto scan_callback = [this, match_pattern, shard_index,
-                          callback](const std::shared_ptr &reply) {
+    auto scan_callback = [this, match_pattern, shard_index, callback](
+                             const std::shared_ptr &reply) {
       OnScanCallback(match_pattern, shard_index, reply, callback);
     };
     // Scan by prefix from Redis.
- std::vector args = {"SCAN", std::to_string(cursor), - "MATCH", match_pattern, - "COUNT", std::to_string(batch_count)}; + std::vector args = {"SCAN", + std::to_string(cursor), + "MATCH", + match_pattern, + "COUNT", + std::to_string(batch_count)}; auto shard_context = redis_client_->GetShardContexts()[shard_index]; Status status = shard_context->RunArgvAsync(args, scan_callback); if (!status.ok()) { @@ -421,8 +434,10 @@ void RedisStoreClient::RedisScanner::Scan(const std::string &match_pattern, } void RedisStoreClient::RedisScanner::OnScanCallback( - const std::string &match_pattern, size_t shard_index, - const std::shared_ptr &reply, const StatusCallback &callback) { + const std::string &match_pattern, + size_t shard_index, + const std::shared_ptr &reply, + const StatusCallback &callback) { RAY_CHECK(reply); std::vector scan_result; size_t cursor = reply->ReadAsScanArray(&scan_result); diff --git a/src/ray/gcs/store_client/redis_store_client.h b/src/ray/gcs/store_client/redis_store_client.h index a762bae62..8f86f35ec 100644 --- a/src/ray/gcs/store_client/redis_store_client.h +++ b/src/ray/gcs/store_client/redis_store_client.h @@ -29,26 +29,34 @@ class RedisStoreClient : public StoreClient { explicit RedisStoreClient(std::shared_ptr redis_client) : redis_client_(std::move(redis_client)) {} - Status AsyncPut(const std::string &table_name, const std::string &key, - const std::string &data, const StatusCallback &callback) override; + Status AsyncPut(const std::string &table_name, + const std::string &key, + const std::string &data, + const StatusCallback &callback) override; - Status AsyncPutWithIndex(const std::string &table_name, const std::string &key, - const std::string &index_key, const std::string &data, + Status AsyncPutWithIndex(const std::string &table_name, + const std::string &key, + const std::string &index_key, + const std::string &data, const StatusCallback &callback) override; - Status AsyncGet(const std::string &table_name, const std::string &key, + Status AsyncGet(const std::string &table_name, + const std::string &key, const OptionalItemCallback &callback) override; - Status AsyncGetByIndex(const std::string &table_name, const std::string &index_key, + Status AsyncGetByIndex(const std::string &table_name, + const std::string &index_key, const MapCallback &callback) override; Status AsyncGetAll(const std::string &table_name, const MapCallback &callback) override; - Status AsyncDelete(const std::string &table_name, const std::string &key, + Status AsyncDelete(const std::string &table_name, + const std::string &key, const StatusCallback &callback) override; - Status AsyncDeleteWithIndex(const std::string &table_name, const std::string &key, + Status AsyncDeleteWithIndex(const std::string &table_name, + const std::string &key, const std::string &index_key, const StatusCallback &callback) override; @@ -61,7 +69,8 @@ class RedisStoreClient : public StoreClient { const std::vector &index_keys, const StatusCallback &callback) override; - Status AsyncDeleteByIndex(const std::string &table_name, const std::string &index_key, + Status AsyncDeleteByIndex(const std::string &table_name, + const std::string &index_key, const StatusCallback &callback) override; int GetNextJobID() override; @@ -86,7 +95,8 @@ class RedisStoreClient : public StoreClient { private: void Scan(const std::string &match_pattern, const StatusCallback &callback); - void OnScanCallback(const std::string &match_pattern, size_t shard_index, + void OnScanCallback(const std::string &match_pattern, + size_t shard_index, 
                       const std::shared_ptr &reply,
                       const StatusCallback &callback);
@@ -108,7 +118,8 @@ class RedisStoreClient : public StoreClient {
     std::shared_ptr redis_client_;
   };
-  Status DoPut(const std::string &key, const std::string &data,
+  Status DoPut(const std::string &key,
+               const std::string &data,
                const StatusCallback &callback);
   Status DeleteByKeys(const std::vector &keys,
@@ -118,7 +129,8 @@
   /// operations.
   static absl::flat_hash_map>>
   GenCommandsByShards(const std::shared_ptr &redis_client,
-                      const std::string &command, const std::vector &keys,
+                      const std::string &command,
+                      const std::vector &keys,
                       int *count);
   /// The separator is used when building redis key.
@@ -127,7 +139,8 @@
   static std::string GenRedisKey(const std::string &table_name, const std::string &key);
-  static std::string GenRedisKey(const std::string &table_name, const std::string &key,
+  static std::string GenRedisKey(const std::string &table_name,
+                                 const std::string &key,
                                  const std::string &index_key);
   static std::string GenRedisMatchPattern(const std::string &table_name);
diff --git a/src/ray/gcs/store_client/store_client.h b/src/ray/gcs/store_client/store_client.h
index 67c5e7622..809e26c36 100644
--- a/src/ray/gcs/store_client/store_client.h
+++ b/src/ray/gcs/store_client/store_client.h
@@ -41,8 +41,10 @@ class StoreClient {
   /// \param data The value of the key that will be written to the table.
   /// \param callback Callback that will be called after write finishes.
   /// \return Status
-  virtual Status AsyncPut(const std::string &table_name, const std::string &key,
-                          const std::string &data, const StatusCallback &callback) = 0;
+  virtual Status AsyncPut(const std::string &table_name,
+                          const std::string &key,
+                          const std::string &data,
+                          const StatusCallback &callback) = 0;
   /// Write data to the given table asynchronously.
   ///
@@ -52,8 +54,10 @@ class StoreClient {
   /// \param data The value of the key that will be written to the table.
   /// \param callback Callback that will be called after write finishes.
   /// \return Status
-  virtual Status AsyncPutWithIndex(const std::string &table_name, const std::string &key,
-                                   const std::string &index_key, const std::string &data,
+  virtual Status AsyncPutWithIndex(const std::string &table_name,
+                                   const std::string &key,
+                                   const std::string &index_key,
+                                   const std::string &data,
                                    const StatusCallback &callback) = 0;
   /// Get data from the given table asynchronously.
   ///
   /// \param key The key to lookup from the table.
   /// \param callback Callback that will be called after read finishes.
   /// \return Status
-  virtual Status AsyncGet(const std::string &table_name, const std::string &key,
+  virtual Status AsyncGet(const std::string &table_name,
+                          const std::string &key,
                           const OptionalItemCallback &callback) = 0;
   /// Get data by index from the given table asynchronously.
@@ -72,7 +77,8 @@ class StoreClient {
   /// \param callback Callback that will be called after read finishes.
   /// \return Status
   virtual Status AsyncGetByIndex(
-      const std::string &table_name, const std::string &index_key,
+      const std::string &table_name,
+      const std::string &index_key,
       const MapCallback &callback) = 0;
   /// Get all data from the given table asynchronously.
@@ -89,7 +95,8 @@ class StoreClient {
   /// \param key The key that will be deleted from the table.
   /// \param callback Callback that will be called after delete finishes.
   /// \return Status
-  virtual Status AsyncDelete(const std::string &table_name, const std::string &key,
+  virtual Status AsyncDelete(const std::string &table_name,
+                             const std::string &key,
                              const StatusCallback &callback) = 0;
   /// Delete data from the given table asynchronously, this can delete
diff --git a/src/ray/gcs/store_client/test/redis_store_client_test.cc b/src/ray/gcs/store_client/test/redis_store_client_test.cc
index eb62031bd..8499d3774 100644
--- a/src/ray/gcs/store_client/test/redis_store_client_test.cc
+++ b/src/ray/gcs/store_client/test/redis_store_client_test.cc
@@ -33,7 +33,9 @@ class RedisStoreClientTest : public StoreClientTestBase {
   static void TearDownTestCase() { TestSetupUtil::ShutDownRedisServers(); }
   void InitStoreClient() override {
-    RedisClientOptions options("127.0.0.1", TEST_REDIS_SERVER_PORTS.front(), "",
+    RedisClientOptions options("127.0.0.1",
+                               TEST_REDIS_SERVER_PORTS.front(),
+                               "",
                                /*enable_sharding_conn=*/false);
     redis_client_ = std::make_shared(options);
     RAY_CHECK_OK(redis_client_->Connect(io_service_pool_->GetAll()));
@@ -69,7 +71,8 @@ TEST_F(RedisStoreClientTest, TestAsyncBatchDeleteWithIndex) {
 int main(int argc, char **argv) {
   InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
-                                         ray::RayLog::ShutDownRayLog, argv[0],
+                                         ray::RayLog::ShutDownRayLog,
+                                         argv[0],
                                          ray::RayLogLevel::INFO,
                                          /*log_dir=*/"");
   ::testing::InitGoogleTest(&argc, argv);
diff --git a/src/ray/gcs/store_client/test/store_client_test_base.h b/src/ray/gcs/store_client/test/store_client_test_base.h
index bfcbd042a..ad385706d 100644
--- a/src/ray/gcs/store_client/test/store_client_test_base.h
+++ b/src/ray/gcs/store_client/test/store_client_test_base.h
@@ -67,11 +67,11 @@ class StoreClientTestBase : public ::testing::Test {
     };
     for (const auto &[key, value] : key_to_value_) {
       ++pending_count_;
-      RAY_CHECK_OK(store_client_->AsyncPut(table_name_, key.Binary(),
-                                           value.SerializeAsString(), put_calllback));
+      RAY_CHECK_OK(store_client_->AsyncPut(
+          table_name_, key.Binary(), value.SerializeAsString(), put_calllback));
       // Make sure no-op callback is handled well
-      RAY_CHECK_OK(store_client_->AsyncPut(table_name_, key.Binary(),
-                                           value.SerializeAsString(), nullptr));
+      RAY_CHECK_OK(store_client_->AsyncPut(
+          table_name_, key.Binary(), value.SerializeAsString(), nullptr));
     }
     WaitPendingDone();
   }
@@ -130,13 +130,17 @@
     auto put_calllback = [this](const Status &status) { --pending_count_; };
     for (const auto &[key, value] : key_to_value_) {
      ++pending_count_;
-      RAY_CHECK_OK(store_client_->AsyncPutWithIndex(
-          table_name_, key.Binary(), key_to_index_[key].Hex(), value.SerializeAsString(),
-          put_calllback));
-      // Make sure no-op callback is handled well
-      RAY_CHECK_OK(store_client_->AsyncPutWithIndex(table_name_, key.Binary(),
+      RAY_CHECK_OK(store_client_->AsyncPutWithIndex(table_name_,
+                                                    key.Binary(),
                                                     key_to_index_[key].Hex(),
-                                                    value.SerializeAsString(), nullptr));
+                                                    value.SerializeAsString(),
+                                                    put_calllback));
+      // Make sure no-op callback is handled well
+      RAY_CHECK_OK(store_client_->AsyncPutWithIndex(table_name_,
+                                                    key.Binary(),
+                                                    key_to_index_[key].Hex(),
+                                                    value.SerializeAsString(),
+                                                    nullptr));
     }
     WaitPendingDone();
   }
@@ -240,8 +244,8 @@
       keys.push_back(key.Binary());
       index_keys.push_back(key_to_index_[key].Hex());
     }
-    RAY_CHECK_OK(store_client_->AsyncBatchDeleteWithIndex(table_name_, keys, index_keys,
-                                                          delete_calllback));
+    RAY_CHECK_OK(store_client_->AsyncBatchDeleteWithIndex(
+        table_name_, keys, index_keys, delete_calllback));
     // Make sure no-op callback is handled well
     RAY_CHECK_OK(
         store_client_->AsyncBatchDeleteWithIndex(table_name_, keys, index_keys, nullptr));
diff --git a/src/ray/gcs/test/asio_test.cc b/src/ray/gcs/test/asio_test.cc
index 4c511203e..6df76b149 100644
--- a/src/ray/gcs/test/asio_test.cc
+++ b/src/ray/gcs/test/asio_test.cc
@@ -78,7 +78,8 @@ TEST_F(RedisAsioTest, TestRedisCommands) {
                    ->PingPort(std::string("127.0.0.1"), TEST_REDIS_SERVER_PORTS.front() + 987)
                    .ok());
   ASSERT_TRUE(shard_context
-                  ->Connect(std::string("127.0.0.1"), TEST_REDIS_SERVER_PORTS.front(),
+                  ->Connect(std::string("127.0.0.1"),
+                            TEST_REDIS_SERVER_PORTS.front(),
                             /*sharding=*/true,
                             /*password=*/std::string())
                  .ok());
@@ -92,7 +93,8 @@
 int main(int argc, char **argv) {
   InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
-                                         ray::RayLog::ShutDownRayLog, argv[0],
+                                         ray::RayLog::ShutDownRayLog,
+                                         argv[0],
                                          ray::RayLogLevel::INFO,
                                          /*log_dir=*/"");
   ::testing::InitGoogleTest(&argc, argv);
diff --git a/src/ray/gcs/test/callback_reply_test.cc b/src/ray/gcs/test/callback_reply_test.cc
index 42baff3f9..cd063e21a 100644
--- a/src/ray/gcs/test/callback_reply_test.cc
+++ b/src/ray/gcs/test/callback_reply_test.cc
@@ -12,9 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "ray/gcs/redis_context.h"
-
 #include "gtest/gtest.h"
+#include "ray/gcs/redis_context.h"
 extern "C" {
 #include "hiredis/hiredis.h"
@@ -70,10 +69,11 @@ TEST(TestCallbackReply, TestParseAsStringArray) {
   redis_reply_array_elements[2] = &redis_reply_nil2;
   redis_reply_array.element = redis_reply_array_elements;
   CallbackReply callback_reply(&redis_reply_array);
-  ASSERT_EQ(callback_reply.ReadAsStringArray(),
-            (std::vector>{
-                std::optional(), std::optional(string1),
-                std::optional()}));
+  ASSERT_EQ(
+      callback_reply.ReadAsStringArray(),
+      (std::vector>{std::optional(),
+                     std::optional(string1),
+                     std::optional()}));
 }
 }
 }  // namespace ray::gcs
diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h
index b1d9d6c5c..b38d14e91 100644
--- a/src/ray/gcs/test/gcs_test_util.h
+++ b/src/ray/gcs/test/gcs_test_util.h
@@ -30,8 +30,12 @@ namespace ray {
 struct Mocker {
   static TaskSpecification GenActorCreationTask(
-      const JobID &job_id, int max_restarts, bool detached, const std::string &name,
-      const std::string &ray_namespace, const rpc::Address &owner_address,
+      const JobID &job_id,
+      int max_restarts,
+      bool detached,
+      const std::string &name,
+      const std::string &ray_namespace,
+      const rpc::Address &owner_address,
       std::unordered_map required_resources =
           std::unordered_map(),
       std::unordered_map required_placement_resources =
@@ -41,21 +45,41 @@ struct Mocker {
     auto task_id = TaskID::ForActorCreationTask(actor_id);
     FunctionDescriptor function_descriptor;
     function_descriptor = FunctionDescriptorBuilder::BuildPython("", "", "", "");
-    builder.SetCommonTaskSpec(task_id, name + ":" + function_descriptor->CallString(),
-                              Language::PYTHON, function_descriptor, job_id,
-                              TaskID::Nil(), 0, TaskID::Nil(), owner_address, 1,
-                              required_resources, required_placement_resources, "", 0);
+    builder.SetCommonTaskSpec(task_id,
+                              name + ":" + function_descriptor->CallString(),
+                              Language::PYTHON,
+                              function_descriptor,
+                              job_id,
+                              TaskID::Nil(),
+                              0,
+                              TaskID::Nil(),
+                              owner_address,
+                              1,
+                              required_resources,
+                              required_placement_resources,
+                              "",
+                              0);
     rpc::SchedulingStrategy scheduling_strategy;
     scheduling_strategy.mutable_default_scheduling_strategy();
-    builder.SetActorCreationTaskSpec(actor_id, {}, scheduling_strategy, max_restarts,
-                                     /*max_task_retries=*/0, {}, 1, detached, name,
+    builder.SetActorCreationTaskSpec(actor_id,
+                                     {},
+                                     scheduling_strategy,
+                                     max_restarts,
+                                     /*max_task_retries=*/0,
+                                     {},
+                                     1,
+                                     detached,
+                                     name,
                                      ray_namespace);
     return builder.Build();
   }
   static rpc::CreateActorRequest GenCreateActorRequest(
-      const JobID &job_id, int max_restarts = 0, bool detached = false,
-      const std::string &name = "", const std::string &ray_namespace = "") {
+      const JobID &job_id,
+      int max_restarts = 0,
+      bool detached = false,
+      const std::string &name = "",
+      const std::string &ray_namespace = "") {
     rpc::Address owner_address;
     owner_address.set_raylet_id(NodeID::FromRandom().Binary());
     owner_address.set_ip_address("1234");
@@ -69,8 +93,11 @@ struct Mocker {
   }
   static rpc::RegisterActorRequest GenRegisterActorRequest(
-      const JobID &job_id, int max_restarts = 0, bool detached = false,
-      const std::string &name = "", const std::string &ray_namespace = "") {
+      const JobID &job_id,
+      int max_restarts = 0,
+      bool detached = false,
+      const std::string &name = "",
+      const std::string &ray_namespace = "") {
     rpc::Address owner_address;
     owner_address.set_raylet_id(NodeID::FromRandom().Binary());
     owner_address.set_ip_address("1234");
@@ -85,7 +112,8 @@ struct Mocker {
   static std::vector> GenBundleSpecifications(
       const PlacementGroupID &placement_group_id,
-      absl::flat_hash_map &unit_resource, int bundles_size = 1) {
+      absl::flat_hash_map &unit_resource,
+      int bundles_size = 1) {
     std::vector> bundle_specs;
     for (int i = 0; i < bundles_size; i++) {
       rpc::Bundle bundle;
@@ -104,7 +132,8 @@ struct Mocker {
   // TODO(@clay4444): Remove this once we did the batch rpc request refactor.
   static BundleSpecification GenBundleCreation(
-      const PlacementGroupID &placement_group_id, const int bundle_index,
+      const PlacementGroupID &placement_group_id,
+      const int bundle_index,
       absl::flat_hash_map<std::string, double> &unit_resource) {
     rpc::Bundle bundle;
     auto mutable_bundle_id = bundle.mutable_bundle_id();
@@ -120,12 +149,19 @@ struct Mocker {
 
   static PlacementGroupSpecification GenPlacementGroupCreation(
       const std::string &name,
       std::vector<std::unordered_map<std::string, double>> &bundles,
-      rpc::PlacementStrategy strategy, const JobID &job_id, const ActorID &actor_id) {
+      rpc::PlacementStrategy strategy,
+      const JobID &job_id,
+      const ActorID &actor_id) {
     PlacementGroupSpecBuilder builder;
     auto placement_group_id = PlacementGroupID::FromRandom();
-    builder.SetPlacementGroupSpec(placement_group_id, name, bundles, strategy,
-                                  /* is_detached */ false, job_id, actor_id,
+    builder.SetPlacementGroupSpec(placement_group_id,
+                                  name,
+                                  bundles,
+                                  strategy,
+                                  /* is_detached */ false,
+                                  job_id,
+                                  actor_id,
                                   /* is_creator_detached */ false);
     return builder.Build();
   }
@@ -133,7 +169,9 @@ struct Mocker {
   static rpc::CreatePlacementGroupRequest GenCreatePlacementGroupRequest(
       const std::string name = "",
       rpc::PlacementStrategy strategy = rpc::PlacementStrategy::SPREAD,
-      int bundles_count = 2, double cpu_num = 1.0, const JobID job_id = JobID::FromInt(1),
+      int bundles_count = 2,
+      double cpu_num = 1.0,
+      const JobID job_id = JobID::FromInt(1),
       const ActorID &actor_id = ActorID::Nil()) {
     rpc::CreatePlacementGroupRequest request;
     std::vector<std::unordered_map<std::string, double>> bundles;
@@ -198,7 +236,8 @@ struct Mocker {
   }
 
   static std::shared_ptr<rpc::AddJobRequest> GenAddJobRequest(
-      const JobID &job_id, const std::string &ray_namespace,
+      const JobID &job_id,
+      const std::string &ray_namespace,
       uint32_t num_java_worker_per_process) {
     auto job_config_data = std::make_shared<rpc::JobConfig>();
     job_config_data->set_ray_namespace(ray_namespace);
diff --git a/src/ray/internal/internal.cc b/src/ray/internal/internal.cc
index 2f25ce33b..0d77b3744 100644
--- a/src/ray/internal/internal.cc
+++ b/src/ray/internal/internal.cc
@@ -24,7 +24,8 @@ using ray::core::TaskOptions;
 
 std::vector<rpc::ObjectReference> SendInternal(const ActorID &peer_actor_id,
                                                std::shared_ptr<LocalMemoryBuffer> buffer,
-                                               RayFunction &function, int return_num) {
+                                               RayFunction &function,
+                                               int return_num) {
   std::unordered_map<std::string, double> resources;
   std::string name = function.GetFunctionDescriptor()->DefaultTaskName();
   TaskOptions options{name, return_num, resources};
diff --git a/src/ray/internal/internal.h b/src/ray/internal/internal.h
index 661f46fe0..1bc9a9063 100644
--- a/src/ray/internal/internal.h
+++ b/src/ray/internal/internal.h
@@ -32,7 +32,8 @@ using ray::core::RayFunction;
 /// \param[out] return_ids return ids from SubmitActorTask.
 std::vector<rpc::ObjectReference> SendInternal(const ActorID &peer_actor_id,
                                                std::shared_ptr<LocalMemoryBuffer> buffer,
-                                               RayFunction &function, int return_num);
+                                               RayFunction &function,
+                                               int return_num);
 
 const stats::TagKeyType TagRegister(const std::string tag_name);
 }  // namespace internal
diff --git a/src/ray/object_manager/chunk_object_reader.cc b/src/ray/object_manager/chunk_object_reader.cc
index a706de89f..3d944e250 100644
--- a/src/ray/object_manager/chunk_object_reader.cc
+++ b/src/ray/object_manager/chunk_object_reader.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
#include "ray/object_manager/chunk_object_reader.h" + #include "ray/util/logging.h" namespace ray { diff --git a/src/ray/object_manager/memory_object_reader.cc b/src/ray/object_manager/memory_object_reader.cc index bd1aeb512..4cf3b4d3e 100644 --- a/src/ray/object_manager/memory_object_reader.cc +++ b/src/ray/object_manager/memory_object_reader.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "ray/object_manager/memory_object_reader.h" + #include namespace ray { @@ -30,7 +31,8 @@ uint64_t MemoryObjectReader::GetMetadataSize() const { const rpc::Address &MemoryObjectReader::GetOwnerAddress() const { return owner_address_; } -bool MemoryObjectReader::ReadFromDataSection(uint64_t offset, uint64_t size, +bool MemoryObjectReader::ReadFromDataSection(uint64_t offset, + uint64_t size, char *output) const { if (offset + size > GetDataSize()) { return false; @@ -39,7 +41,8 @@ bool MemoryObjectReader::ReadFromDataSection(uint64_t offset, uint64_t size, return true; } -bool MemoryObjectReader::ReadFromMetadataSection(uint64_t offset, uint64_t size, +bool MemoryObjectReader::ReadFromMetadataSection(uint64_t offset, + uint64_t size, char *output) const { if (offset + size > GetMetadataSize()) { return false; diff --git a/src/ray/object_manager/memory_object_reader.h b/src/ray/object_manager/memory_object_reader.h index 3b8273946..51346ebd5 100644 --- a/src/ray/object_manager/memory_object_reader.h +++ b/src/ray/object_manager/memory_object_reader.h @@ -32,7 +32,8 @@ class MemoryObjectReader : public IObjectReader { const rpc::Address &GetOwnerAddress() const override; bool ReadFromDataSection(uint64_t offset, uint64_t size, char *output) const override; - bool ReadFromMetadataSection(uint64_t offset, uint64_t size, + bool ReadFromMetadataSection(uint64_t offset, + uint64_t size, char *output) const override; private: diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc index a91f64952..0a5ad72aa 100644 --- a/src/ray/object_manager/object_buffer_pool.cc +++ b/src/ray/object_manager/object_buffer_pool.cc @@ -96,11 +96,12 @@ ObjectBufferPool::CreateObjectReader(const ObjectID &object_id, ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address, - uint64_t data_size, uint64_t metadata_size, + uint64_t data_size, + uint64_t metadata_size, uint64_t chunk_index) { absl::MutexLock lock(&pool_mutex_); - RAY_RETURN_NOT_OK(EnsureBufferExists(object_id, owner_address, data_size, metadata_size, - chunk_index)); + RAY_RETURN_NOT_OK(EnsureBufferExists( + object_id, owner_address, data_size, metadata_size, chunk_index)); auto &state = create_buffer_state_.at(object_id); if (chunk_index >= state.chunk_state.size()) { return ray::Status::IOError("Object size mismatch"); @@ -113,8 +114,10 @@ ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id, return ray::Status::OK(); } -void ObjectBufferPool::WriteChunk(const ObjectID &object_id, uint64_t data_size, - uint64_t metadata_size, const uint64_t chunk_index, +void ObjectBufferPool::WriteChunk(const ObjectID &object_id, + uint64_t data_size, + uint64_t metadata_size, + const uint64_t chunk_index, const std::string &data) { absl::MutexLock lock(&pool_mutex_); auto it = create_buffer_state_.find(object_id); @@ -162,7 +165,9 @@ void ObjectBufferPool::AbortCreateInternal(const ObjectID &object_id) { } std::vector ObjectBufferPool::BuildChunks( - const ObjectID &object_id, uint8_t *data, uint64_t data_size, + const ObjectID &object_id, + uint8_t *data, + 
+    uint64_t data_size,
     std::shared_ptr<Buffer> buffer_ref) {
   uint64_t space_remaining = data_size;
   std::vector<ChunkInfo> chunks;
@@ -173,8 +178,8 @@ std::vector<ObjectBufferPool::ChunkInfo> ObjectBufferPool::BuildChunks(
       chunks.emplace_back(chunks.size(), data + position, space_remaining, buffer_ref);
       space_remaining = 0;
     } else {
-      chunks.emplace_back(chunks.size(), data + position, default_chunk_size_,
-                          buffer_ref);
+      chunks.emplace_back(
+          chunks.size(), data + position, default_chunk_size_, buffer_ref);
       space_remaining -= default_chunk_size_;
     }
   }
@@ -236,8 +241,12 @@ ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id,
   // Release pool_mutex_ during the blocking create call.
   pool_mutex_.Unlock();
   Status s = store_client_->CreateAndSpillIfNeeded(
-      object_id, owner_address, static_cast<int64_t>(object_size), nullptr,
-      static_cast<int64_t>(metadata_size), &data,
+      object_id,
+      owner_address,
+      static_cast<int64_t>(object_size),
+      nullptr,
+      static_cast<int64_t>(metadata_size),
+      &data,
       plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet);
   pool_mutex_.Lock();
@@ -263,8 +272,10 @@ ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id,
   uint8_t *mutable_data = data->Data();
   uint64_t num_chunks = GetNumChunks(data_size);
   auto inserted = create_buffer_state_.emplace(
-      std::piecewise_construct, std::forward_as_tuple(object_id),
-      std::forward_as_tuple(metadata_size, data_size,
+      std::piecewise_construct,
+      std::forward_as_tuple(object_id),
+      std::forward_as_tuple(metadata_size,
+                            data_size,
                             BuildChunks(object_id, mutable_data, data_size, data)));
   RAY_CHECK(inserted.first->second.chunk_info.size() == num_chunks);
   RAY_LOG(DEBUG) << "Created object " << object_id
diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h
index 21b7d69dc..e514b2d72 100644
--- a/src/ray/object_manager/object_buffer_pool.h
+++ b/src/ray/object_manager/object_buffer_pool.h
@@ -38,7 +38,9 @@ class ObjectBufferPool {
   /// This is the structure returned whenever an object chunk is
   /// accessed via Get and Create.
   struct ChunkInfo {
-    ChunkInfo(uint64_t chunk_index, uint8_t *data, uint64_t buffer_length,
+    ChunkInfo(uint64_t chunk_index,
+              uint8_t *data,
+              uint64_t buffer_length,
               std::shared_ptr<Buffer> buffer_ref)
         : chunk_index(chunk_index),
           data(data),
@@ -106,8 +108,10 @@ class ObjectBufferPool {
   /// An IOError status is returned if object creation on the store client fails,
   /// or if create is invoked consecutively on the same chunk
   /// (with no intermediate AbortCreateChunk).
-  ray::Status CreateChunk(const ObjectID &object_id, const rpc::Address &owner_address,
-                          uint64_t data_size, uint64_t metadata_size,
+  ray::Status CreateChunk(const ObjectID &object_id,
+                          const rpc::Address &owner_address,
+                          uint64_t data_size,
+                          uint64_t metadata_size,
                           uint64_t chunk_index) LOCKS_EXCLUDED(pool_mutex_);
 
   /// Write to a Chunk of an object. If all chunks of an object are written,
@@ -120,9 +124,11 @@ class ObjectBufferPool {
   /// \param object_id The ObjectID.
   /// \param chunk_index The index of the chunk.
   /// \param data The data to write into the chunk.
-  void WriteChunk(const ObjectID &object_id, uint64_t data_size, uint64_t metadata_size,
-                  uint64_t chunk_index, const std::string &data)
-      LOCKS_EXCLUDED(pool_mutex_);
+  void WriteChunk(const ObjectID &object_id,
+                  uint64_t data_size,
+                  uint64_t metadata_size,
+                  uint64_t chunk_index,
+                  const std::string &data) LOCKS_EXCLUDED(pool_mutex_);
 
   /// Free a list of objects from object store.
   ///
@@ -142,7 +148,8 @@ class ObjectBufferPool {
  private:
   /// Splits an object into ceil(data_size/chunk_size) chunks, which will
   /// either be read or written to in parallel.
-  std::vector<ChunkInfo> BuildChunks(const ObjectID &object_id, uint8_t *data,
+  std::vector<ChunkInfo> BuildChunks(const ObjectID &object_id,
+                                     uint8_t *data,
                                      uint64_t data_size,
                                      std::shared_ptr<Buffer> buffer_ref)
       EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_);
@@ -152,8 +159,10 @@ class ObjectBufferPool {
   /// Must hold pool_mutex_ when calling this function. pool_mutex_ can be released
   /// during the call.
   ray::Status EnsureBufferExists(const ObjectID &object_id,
-                                 const rpc::Address &owner_address, uint64_t data_size,
-                                 uint64_t metadata_size, uint64_t chunk_index)
+                                 const rpc::Address &owner_address,
+                                 uint64_t data_size,
+                                 uint64_t metadata_size,
+                                 uint64_t chunk_index)
       EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_);
 
   void AbortCreateInternal(const ObjectID &object_id)
@@ -164,7 +173,8 @@ class ObjectBufferPool {
 
   /// Holds the state of creating chunks. Members are protected by pool_mutex_.
   struct CreateBufferState {
-    CreateBufferState(uint64_t metadata_size, uint64_t data_size,
+    CreateBufferState(uint64_t metadata_size,
+                      uint64_t data_size,
                       std::vector<ChunkInfo> chunk_info)
         : metadata_size(metadata_size),
           data_size(data_size),
diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h
index 1a37d732d..f12a7e6e3 100644
--- a/src/ray/object_manager/object_directory.h
+++ b/src/ray/object_manager/object_directory.h
@@ -42,9 +42,12 @@ struct RemoteConnectionInfo {
 };
 
 /// Callback for object location notifications.
-using OnLocationsFound = std::function<void(
-    const ObjectID &object_id, const std::unordered_set<NodeID> &,
-    const std::string &, const NodeID &, bool pending_creation, size_t object_size)>;
+using OnLocationsFound = std::function<void(const ObjectID &object_id,
+                                            const std::unordered_set<NodeID> &,
+                                            const std::string &,
+                                            const NodeID &,
+                                            bool pending_creation,
+                                            size_t object_size)>;
 
 class IObjectDirectory {
  public:
@@ -113,7 +116,8 @@ class IObjectDirectory {
   /// \param object_id The object id that was put into the store.
   /// \param node_id The node id corresponding to this node.
   /// \param object_info Additional information about the object.
-  virtual void ReportObjectAdded(const ObjectID &object_id, const NodeID &node_id,
+  virtual void ReportObjectAdded(const ObjectID &object_id,
+                                 const NodeID &node_id,
                                  const ObjectInfo &object_info) = 0;
 
   /// Report objects removed from this node's store to the object directory.
@@ -121,7 +125,8 @@ class IObjectDirectory {
   /// \param object_id The object id that was removed from the store.
   /// \param node_id The node id corresponding to this node.
   /// \param object_info Additional information about the object.
-  virtual void ReportObjectRemoved(const ObjectID &object_id, const NodeID &node_id,
+  virtual void ReportObjectRemoved(const ObjectID &object_id,
+                                   const NodeID &node_id,
                                    const ObjectInfo &object_info) = 0;
 
   /// Record metrics.
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index d5830ffc4..fdeb516fc 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -29,14 +29,19 @@ ObjectStoreRunner::ObjectStoreRunner(const ObjectManagerConfig &config,
                                      std::function<void()> object_store_full_callback,
                                      AddObjectCallback add_object_callback,
                                      DeleteObjectCallback delete_object_callback) {
-  plasma::plasma_store_runner.reset(new plasma::PlasmaStoreRunner(
-      config.store_socket_name, config.object_store_memory, config.huge_pages,
-      config.plasma_directory, config.fallback_directory));
+  plasma::plasma_store_runner.reset(
+      new plasma::PlasmaStoreRunner(config.store_socket_name,
+                                    config.object_store_memory,
+                                    config.huge_pages,
+                                    config.plasma_directory,
+                                    config.fallback_directory));
   // Initialize object store.
-  store_thread_ =
-      std::thread(&plasma::PlasmaStoreRunner::Start, plasma::plasma_store_runner.get(),
-                  spill_objects_callback, object_store_full_callback, add_object_callback,
-                  delete_object_callback);
+  store_thread_ = std::thread(&plasma::PlasmaStoreRunner::Start,
+                              plasma::plasma_store_runner.get(),
+                              spill_objects_callback,
+                              object_store_full_callback,
+                              add_object_callback,
+                              delete_object_callback);
   // Sleep for some time until the store is working. This can suppress some
   // connection warnings.
   std::this_thread::sleep_for(std::chrono::microseconds(500));
@@ -49,13 +54,16 @@ ObjectStoreRunner::~ObjectStoreRunner() {
 }
 
 ObjectManager::ObjectManager(
-    instrumented_io_context &main_service, const NodeID &self_node_id,
-    const ObjectManagerConfig &config, IObjectDirectory *object_directory,
+    instrumented_io_context &main_service,
+    const NodeID &self_node_id,
+    const ObjectManagerConfig &config,
+    IObjectDirectory *object_directory,
     RestoreSpilledObjectCallback restore_spilled_object,
     std::function<std::string(const ObjectID &)> get_spilled_object_url,
     SpillObjectsCallback spill_objects_callback,
     std::function<void()> object_store_full_callback,
-    AddObjectCallback add_object_callback, DeleteObjectCallback delete_object_callback,
+    AddObjectCallback add_object_callback,
+    DeleteObjectCallback delete_object_callback,
    std::function<std::unique_ptr<RayObject>(const ObjectID &object_id)> pin_object,
     const std::function<void(const ObjectID &)> fail_pull_request)
     : main_service_(&main_service),
@@ -63,12 +71,15 @@ ObjectManager::ObjectManager(
       config_(config),
       object_directory_(object_directory),
       object_store_internal_(
-          config, spill_objects_callback, object_store_full_callback,
+          config,
+          spill_objects_callback,
+          object_store_full_callback,
          /*add_object_callback=*/
-          [this, add_object_callback =
-                     std::move(add_object_callback)](const ObjectInfo &object_info) {
+          [this, add_object_callback = std::move(add_object_callback)](
+              const ObjectInfo &object_info) {
            main_service_->post(
-                [this, object_info,
+                [this,
+                 object_info,
                  add_object_callback = std::move(add_object_callback)]() {
                   HandleObjectAdded(object_info);
                   add_object_callback(object_info);
@@ -76,10 +87,11 @@ ObjectManager::ObjectManager(
                 "ObjectManager.ObjectAdded");
           },
           /*delete_object_callback=*/
-          [this, delete_object_callback =
-                     std::move(delete_object_callback)](const ObjectID &object_id) {
+          [this, delete_object_callback = std::move(delete_object_callback)](
+              const ObjectID &object_id) {
            main_service_->post(
-                [this, object_id,
+                [this,
+                 object_id,
                  delete_object_callback = std::move(delete_object_callback)]() {
                   HandleObjectDeleted(object_id);
                   delete_object_callback(object_id);
@@ -89,7 +101,8 @@
       buffer_pool_store_client_(std::make_shared<plasma::PlasmaClient>()),
       buffer_pool_(buffer_pool_store_client_, config_.object_chunk_size),
       rpc_work_(rpc_service_),
-      object_manager_server_("ObjectManager", config_.object_manager_port,
+      object_manager_server_("ObjectManager",
+                             config_.object_manager_port,
                              config_.object_manager_address == "127.0.0.1",
                              config_.rpc_service_threads_number),
       object_manager_service_(rpc_service_, *this),
@@ -124,10 +137,17 @@ ObjectManager::ObjectManager(
   if (available_memory < 0) {
     available_memory = 0;
   }
-  pull_manager_.reset(new PullManager(
-      self_node_id_, object_is_local, send_pull_request, cancel_pull_request,
-      fail_pull_request, restore_spilled_object_, get_time, config.pull_timeout_ms,
-      available_memory, pin_object, get_spilled_object_url));
+  pull_manager_.reset(new PullManager(self_node_id_,
+                                      object_is_local,
+                                      send_pull_request,
+                                      cancel_pull_request,
+                                      fail_pull_request,
+                                      restore_spilled_object_,
+                                      get_time,
+                                      config.pull_timeout_ms,
+                                      available_memory,
+                                      pin_object,
+                                      get_spilled_object_url));
 
   RAY_CHECK_OK(
       buffer_pool_store_client_->Connect(config_.store_socket_name.c_str(), "", 0, 300));
@@ -214,13 +234,19 @@ uint64_t ObjectManager::Pull(const std::vector<rpc::ObjectReference> &object_ref
   std::vector<rpc::ObjectReference> objects_to_locate;
   auto request_id = pull_manager_->Pull(object_refs, prio, &objects_to_locate);
 
-  const auto &callback =
-      [this](const ObjectID &object_id, const std::unordered_set<NodeID> &client_ids,
-             const std::string &spilled_url, const NodeID &spilled_node_id,
-             bool pending_creation, size_t object_size) {
-        pull_manager_->OnLocationChange(object_id, client_ids, spilled_url,
-                                        spilled_node_id, pending_creation, object_size);
-      };
+  const auto &callback = [this](const ObjectID &object_id,
+                                const std::unordered_set<NodeID> &client_ids,
+                                const std::string &spilled_url,
+                                const NodeID &spilled_node_id,
+                                bool pending_creation,
+                                size_t object_size) {
+    pull_manager_->OnLocationChange(object_id,
+                                    client_ids,
+                                    spilled_url,
+                                    spilled_node_id,
+                                    pending_creation,
+                                    object_size);
+  };
 
   for (const auto &ref : objects_to_locate) {
     // Subscribe to object notifications. A notification will be received every
@@ -291,9 +317,12 @@ void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id,
   }
 }
 
-void ObjectManager::HandleSendFinished(const ObjectID &object_id, const NodeID &node_id,
-                                       uint64_t chunk_index, double start_time,
-                                       double end_time, ray::Status status) {
+void ObjectManager::HandleSendFinished(const ObjectID &object_id,
+                                       const NodeID &node_id,
+                                       uint64_t chunk_index,
+                                       double start_time,
+                                       double end_time,
+                                       ray::Status status) {
   RAY_LOG(DEBUG) << "HandleSendFinished on " << self_node_id_ << " to " << node_id
                  << " of object " << object_id << " chunk " << chunk_index
                  << ", status: " << status.ToString();
@@ -385,12 +414,14 @@ void ObjectManager::PushLocalObject(const ObjectID &object_id, const NodeID &nod
     local_objects_[object_id].object_info.metadata_size = 1;
   }
 
-  PushObjectInternal(object_id, node_id,
+  PushObjectInternal(object_id,
+                     node_id,
                      std::make_shared<ChunkObjectReader>(std::move(object_reader),
                                                          config_.object_chunk_size));
 }
 
-void ObjectManager::PushFromFilesystem(const ObjectID &object_id, const NodeID &node_id,
+void ObjectManager::PushFromFilesystem(const ObjectID &object_id,
+                                       const NodeID &node_id,
                                        const std::string &spilled_url) {
   // SpilledObjectReader::CreateSpilledObjectReader does synchronous IO; schedule it off
   // main thread.
@@ -411,7 +442,9 @@ void ObjectManager::PushFromFilesystem(const ObjectID &object_id, const NodeID &
         // Schedule PushObjectInternal back to main_service as PushObjectInternal access
         // thread unsafe datastructure.
         main_service_->post(
-            [this, object_id, node_id,
+            [this,
+             object_id,
+             node_id,
              chunk_object_reader = std::move(chunk_object_reader)]() {
               PushObjectInternal(object_id, node_id, std::move(chunk_object_reader));
             },
@@ -420,7 +453,8 @@ void ObjectManager::PushFromFilesystem(const ObjectID &object_id, const NodeID &
       "ObjectManager.CreateSpilledObject");
 }
 
-void ObjectManager::PushObjectInternal(const ObjectID &object_id, const NodeID &node_id,
+void ObjectManager::PushObjectInternal(const ObjectID &object_id,
+                                       const NodeID &node_id,
                                        std::shared_ptr<ChunkObjectReader> chunk_reader) {
   auto rpc_client = GetRpcClient(node_id);
   if (!rpc_client) {
@@ -442,7 +476,11 @@ void ObjectManager::PushObjectInternal(const ObjectID &object_id, const NodeID &
         // Post to the multithreaded RPC event loop so that data is copied
         // off of the main thread.
         SendObjectChunk(
-            push_id, object_id, node_id, chunk_id, rpc_client,
+            push_id,
+            object_id,
+            node_id,
+            chunk_id,
+            rpc_client,
             [=](const Status &status) {
               // Post back to the main event loop because the
               // PushManager is thread-safe.
@@ -458,8 +496,10 @@ void ObjectManager::PushObjectInternal(const ObjectID &object_id, const NodeID &
       });
 }
 
-void ObjectManager::SendObjectChunk(const UniqueID &push_id, const ObjectID &object_id,
-                                    const NodeID &node_id, uint64_t chunk_index,
+void ObjectManager::SendObjectChunk(const UniqueID &push_id,
+                                    const ObjectID &object_id,
+                                    const NodeID &node_id,
+                                    uint64_t chunk_index,
                                     std::shared_ptr<rpc::ObjectManagerClient> rpc_client,
                                     std::function<void(const Status &)> on_complete,
                                     std::shared_ptr<ChunkObjectReader> chunk_reader) {
@@ -504,7 +544,8 @@ void ObjectManager::SendObjectChunk(const UniqueID &push_id, const ObjectID &obj
 }
 
 /// Implementation of ObjectManagerServiceHandler
-void ObjectManager::HandlePush(const rpc::PushRequest &request, rpc::PushReply *reply,
+void ObjectManager::HandlePush(const rpc::PushRequest &request,
+                               rpc::PushReply *reply,
                                rpc::SendReplyCallback send_reply_callback) {
   ObjectID object_id = ObjectID::FromBinary(request.object_id());
   NodeID node_id = NodeID::FromBinary(request.node_id());
@@ -516,8 +557,8 @@ void ObjectManager::HandlePush(const rpc::PushRequest &request, rpc::PushReply *
   const rpc::Address &owner_address = request.owner_address();
   const std::string &data = request.data();
 
-  bool success = ReceiveObjectChunk(node_id, object_id, owner_address, data_size,
-                                    metadata_size, chunk_index, data);
+  bool success = ReceiveObjectChunk(
+      node_id, object_id, owner_address, data_size, metadata_size, chunk_index, data);
   num_chunks_received_total_++;
   if (!success) {
     num_chunks_received_total_failed_++;
@@ -530,10 +571,13 @@ void ObjectManager::HandlePush(const rpc::PushRequest &request, rpc::PushReply *
   send_reply_callback(Status::OK(), nullptr, nullptr);
 }
 
-bool ObjectManager::ReceiveObjectChunk(const NodeID &node_id, const ObjectID &object_id,
+bool ObjectManager::ReceiveObjectChunk(const NodeID &node_id,
+                                       const ObjectID &object_id,
                                        const rpc::Address &owner_address,
-                                       uint64_t data_size, uint64_t metadata_size,
-                                       uint64_t chunk_index, const std::string &data) {
+                                       uint64_t data_size,
+                                       uint64_t metadata_size,
+                                       uint64_t chunk_index,
+                                       const std::string &data) {
   RAY_LOG(DEBUG) << "ReceiveObjectChunk on " << self_node_id_ << " from " << node_id
                  << " of object " << object_id << " chunk index: " << chunk_index
                  << ", chunk data size: " << data.size()
@@ -544,8 +588,8 @@ bool ObjectManager::ReceiveObjectChunk(const NodeID &node_id, const ObjectID &ob
     // This object is no longer being actively pulled. Do not create the object.
     return false;
   }
-  auto chunk_status = buffer_pool_.CreateChunk(object_id, owner_address, data_size,
-                                               metadata_size, chunk_index);
+  auto chunk_status = buffer_pool_.CreateChunk(
+      object_id, owner_address, data_size, metadata_size, chunk_index);
   if (!pull_manager_->IsObjectActive(object_id)) {
     num_chunks_received_cancelled_++;
     // This object is no longer being actively pulled. Abort the object. We
@@ -567,7 +611,8 @@ bool ObjectManager::ReceiveObjectChunk(const NodeID &node_id, const ObjectID &ob
   }
 }
 
-void ObjectManager::HandlePull(const rpc::PullRequest &request, rpc::PullReply *reply,
+void ObjectManager::HandlePull(const rpc::PullRequest &request,
+                               rpc::PullReply *reply,
                                rpc::SendReplyCallback send_reply_callback) {
   ObjectID object_id = ObjectID::FromBinary(request.object_id());
   NodeID node_id = NodeID::FromBinary(request.node_id());
@@ -620,12 +665,14 @@ void ObjectManager::SpreadFreeObjectsRequest(
   }
 
   for (auto &rpc_client : rpc_clients) {
-    rpc_client->FreeObjects(free_objects_request, [](const Status &status,
-                                                     const rpc::FreeObjectsReply &reply) {
-      if (!status.ok()) {
-        RAY_LOG(WARNING) << "Send free objects request failed due to" << status.message();
-      }
-    });
+    rpc_client->FreeObjects(free_objects_request,
+                            [](const Status &status, const rpc::FreeObjectsReply &reply) {
+                              if (!status.ok()) {
+                                RAY_LOG(WARNING)
+                                    << "Send free objects request failed due to"
+                                    << status.message();
+                              }
+                            });
   }
 }
 
diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h
index eb0aa4f46..be490cff1 100644
--- a/src/ray/object_manager/object_manager.h
+++ b/src/ray/object_manager/object_manager.h
@@ -127,7 +127,8 @@ class ObjectManager : public ObjectManagerInterface,
   /// \param request Push request including the object chunk data
   /// \param reply Reply to the sender
   /// \param send_reply_callback Callback of the request
-  void HandlePush(const rpc::PushRequest &request, rpc::PushReply *reply,
+  void HandlePush(const rpc::PushRequest &request,
+                  rpc::PushReply *reply,
                   rpc::SendReplyCallback send_reply_callback) override;
 
   /// Handle pull request from remote object manager
   ///
   /// \param request Pull request
   /// \param reply Reply
   /// \param send_reply_callback Callback of request
-  void HandlePull(const rpc::PullRequest &request, rpc::PullReply *reply,
+  void HandlePull(const rpc::PullRequest &request,
+                  rpc::PullReply *reply,
                   rpc::SendReplyCallback send_reply_callback) override;
 
   /// Handle free objects request
@@ -163,13 +165,16 @@ class ObjectManager : public ObjectManagerInterface,
   /// \param config ObjectManager configuration.
   /// \param object_directory An object implementing the object directory interface.
   explicit ObjectManager(
-      instrumented_io_context &main_service, const NodeID &self_node_id,
-      const ObjectManagerConfig &config, IObjectDirectory *object_directory,
+      instrumented_io_context &main_service,
+      const NodeID &self_node_id,
+      const ObjectManagerConfig &config,
+      IObjectDirectory *object_directory,
       RestoreSpilledObjectCallback restore_spilled_object,
       std::function<std::string(const ObjectID &)> get_spilled_object_url,
       SpillObjectsCallback spill_objects_callback,
       std::function<void()> object_store_full_callback,
-      AddObjectCallback add_object_callback, DeleteObjectCallback delete_object_callback,
+      AddObjectCallback add_object_callback,
+      DeleteObjectCallback delete_object_callback,
       std::function<std::unique_ptr<RayObject>(const ObjectID &object_id)> pin_object,
       const std::function<void(const ObjectID &)> fail_pull_request);
 
@@ -266,7 +271,8 @@ class ObjectManager : public ObjectManagerInterface,
   /// \param node_id The remote node's id.
   /// \param spilled_url The url of the spilled object.
   /// \return Void.
-  void PushFromFilesystem(const ObjectID &object_id, const NodeID &node_id,
+  void PushFromFilesystem(const ObjectID &object_id,
+                          const NodeID &node_id,
                           const std::string &spilled_url);
 
   /// The internal implementation of pushing an object.
@@ -275,7 +281,8 @@ class ObjectManager : public ObjectManagerInterface,
   /// \param node_id The remote node's id.
   /// \param chunk_reader Chunk reader used to read a chunk of the object
   /// Status::OK() if the read succeeded.
-  void PushObjectInternal(const ObjectID &object_id, const NodeID &node_id,
+  void PushObjectInternal(const ObjectID &object_id,
+                          const NodeID &node_id,
                           std::shared_ptr<ChunkObjectReader> chunk_reader);
 
   /// Send one chunk of the object to remote object manager
@@ -289,8 +296,10 @@ class ObjectManager : public ObjectManagerInterface,
   /// \param rpc_client Rpc client used to send message to remote object manager
   /// \param on_complete Callback when the chunk is sent
   /// \param chunk_reader Chunk reader used to read a chunk of the object
-  void SendObjectChunk(const UniqueID &push_id, const ObjectID &object_id,
-                       const NodeID &node_id, uint64_t chunk_index,
+  void SendObjectChunk(const UniqueID &push_id,
+                       const ObjectID &object_id,
+                       const NodeID &node_id,
+                       uint64_t chunk_index,
                        std::shared_ptr<rpc::ObjectManagerClient> rpc_client,
                        std::function<void(const Status &)> on_complete,
                        std::shared_ptr<ChunkObjectReader> chunk_reader);
@@ -322,8 +331,11 @@ class ObjectManager : public ObjectManagerInterface,
   /// chunk.
   /// \param status The status of the send (e.g., did it succeed or fail).
   /// \return Void.
-  void HandleSendFinished(const ObjectID &object_id, const NodeID &node_id,
-                          uint64_t chunk_index, double start_time_us, double end_time_us,
+  void HandleSendFinished(const ObjectID &object_id,
+                          const NodeID &node_id,
+                          uint64_t chunk_index,
+                          double start_time_us,
+                          double end_time_us,
                           ray::Status status);
 
   /// Handle Push task timeout.
@@ -349,9 +361,12 @@ class ObjectManager : public ObjectManagerInterface,
   /// \return Whether the chunk was successfully written into the local object
   /// store. This can fail if the chunk was already received in the past, or if
   /// the object is no longer being actively pulled.
-  bool ReceiveObjectChunk(const NodeID &node_id, const ObjectID &object_id,
-                          const rpc::Address &owner_address, uint64_t data_size,
-                          uint64_t metadata_size, uint64_t chunk_index,
+  bool ReceiveObjectChunk(const NodeID &node_id,
+                          const ObjectID &object_id,
+                          const rpc::Address &owner_address,
+                          uint64_t data_size,
+                          uint64_t metadata_size,
+                          uint64_t chunk_index,
                           const std::string &data);
 
   /// Send pull request
@@ -406,7 +421,8 @@ class ObjectManager : public ObjectManagerInterface,
   /// Maintains a map of push requests that have not been fulfilled due to an object not
   /// being local. Objects are removed from this map after push_timeout_ms have elapsed.
   std::unordered_map<
-      ObjectID, std::unordered_map<NodeID, std::unique_ptr<boost::asio::deadline_timer>>>
+      ObjectID,
+      std::unordered_map<NodeID, std::unique_ptr<boost::asio::deadline_timer>>>
      unfulfilled_push_requests_;
 
   /// The gRPC server.
diff --git a/src/ray/object_manager/object_reader.h b/src/ray/object_manager/object_reader.h
index 96f431c5e..295e917ad 100644
--- a/src/ray/object_manager/object_reader.h
+++ b/src/ray/object_manager/object_reader.h
@@ -40,7 +40,8 @@ class IObjectReader {
   /// \param size number of bytes to copy.
   /// \param output pointer to the memory location to copy to.
   /// \return bool.
-  virtual bool ReadFromDataSection(uint64_t offset, uint64_t size,
+  virtual bool ReadFromDataSection(uint64_t offset,
+                                   uint64_t size,
                                    char *output) const = 0;
   /// Read from metadata sections into output.
   /// Return false if the object is corrupted or size/offset is invalid.
@@ -49,7 +50,8 @@ class IObjectReader {
   /// \param size number of bytes to copy.
   /// \param output pointer to the memory location to copy to.
   /// \return bool.
-  virtual bool ReadFromMetadataSection(uint64_t offset, uint64_t size,
+  virtual bool ReadFromMetadataSection(uint64_t offset,
+                                       uint64_t size,
                                        char *output) const = 0;
 };
 }  // namespace ray
diff --git a/src/ray/object_manager/ownership_based_object_directory.cc b/src/ray/object_manager/ownership_based_object_directory.cc
index bae8a86c1..2e2a6df61 100644
--- a/src/ray/object_manager/ownership_based_object_directory.cc
+++ b/src/ray/object_manager/ownership_based_object_directory.cc
@@ -19,9 +19,11 @@ namespace ray {
 
 OwnershipBasedObjectDirectory::OwnershipBasedObjectDirectory(
-    instrumented_io_context &io_service, std::shared_ptr<gcs::GcsClient> &gcs_client,
+    instrumented_io_context &io_service,
+    std::shared_ptr<gcs::GcsClient> &gcs_client,
     pubsub::SubscriberInterface *object_location_subscriber,
-    rpc::CoreWorkerClientPool *owner_client_pool, int64_t max_object_report_batch_size,
+    rpc::CoreWorkerClientPool *owner_client_pool,
+    int64_t max_object_report_batch_size,
     std::function<void(const ObjectID &, rpc::ErrorType)> mark_as_failed)
     : io_service_(io_service),
       gcs_client_(gcs_client),
@@ -48,8 +50,10 @@ void FilterRemovedNodes(std::shared_ptr<gcs::GcsClient> gcs_client,
 
 /// Update object location data based on response from the owning core worker.
 bool UpdateObjectLocations(const rpc::WorkerObjectLocationsPubMessage &location_info,
                            std::shared_ptr<gcs::GcsClient> gcs_client,
-                           std::unordered_set<NodeID> *node_ids, std::string *spilled_url,
-                           NodeID *spilled_node_id, bool *pending_creation,
+                           std::unordered_set<NodeID> *node_ids,
+                           std::string *spilled_url,
+                           NodeID *spilled_node_id,
+                           bool *pending_creation,
                            size_t *object_size) {
   bool is_updated = false;
   std::unordered_set<NodeID> new_node_ids;
@@ -188,8 +192,9 @@ void OwnershipBasedObjectDirectory::SendObjectLocationUpdateBatchIfNeeded(
   in_flight_requests_.emplace(worker_id);
   auto owner_client = GetClient(owner_address);
   owner_client->UpdateObjectLocationBatch(
-      request, [this, worker_id, node_id, owner_address](
-                   Status status, const rpc::UpdateObjectLocationBatchReply &reply) {
+      request,
+      [this, worker_id, node_id, owner_address](
+          Status status, const rpc::UpdateObjectLocationBatchReply &reply) {
         auto in_flight_request_it = in_flight_requests_.find(worker_id);
         RAY_CHECK(in_flight_request_it != in_flight_requests_.end());
         in_flight_requests_.erase(in_flight_request_it);
@@ -216,7 +221,8 @@ void OwnershipBasedObjectDirectory::SendObjectLocationUpdateBatchIfNeeded(
 }
 
 void OwnershipBasedObjectDirectory::ObjectLocationSubscriptionCallback(
-    const rpc::WorkerObjectLocationsPubMessage &location_info, const ObjectID &object_id,
+    const rpc::WorkerObjectLocationsPubMessage &location_info,
+    const ObjectID &object_id,
     bool location_lookup_failed) {
   // Objects are added to this map in SubscribeObjectLocations.
   auto it = listeners_.find(object_id);
@@ -233,10 +239,13 @@ void OwnershipBasedObjectDirectory::ObjectLocationSubscriptionCallback(
     RAY_LOG(DEBUG) << "Object " << object_id << " is on node " << node_id << " alive? "
                    << !gcs_client_->Nodes().IsRemoved(node_id);
   }
-  auto location_updated = UpdateObjectLocations(
-      location_info, gcs_client_, &it->second.current_object_locations,
-      &it->second.spilled_url, &it->second.spilled_node_id, &it->second.pending_creation,
-      &it->second.object_size);
+  auto location_updated = UpdateObjectLocations(location_info,
+                                                gcs_client_,
+                                                &it->second.current_object_locations,
+                                                &it->second.spilled_url,
+                                                &it->second.spilled_node_id,
+                                                &it->second.pending_creation,
+                                                &it->second.object_size);
 
   // If the lookup has failed, that means the object is lost. Trigger the callback in this
   // case to handle failure properly.
@@ -260,16 +269,21 @@ void OwnershipBasedObjectDirectory::ObjectLocationSubscriptionCallback(
       // We can call the callback directly without worrying about invalidating caller
      // iterators since this is already running in the subscription callback stack.
       // See https://github.com/ray-project/ray/issues/2959.
-      callback_pair.second(object_id, it->second.current_object_locations,
-                           it->second.spilled_url, it->second.spilled_node_id,
-                           it->second.pending_creation, it->second.object_size);
+      callback_pair.second(object_id,
+                           it->second.current_object_locations,
+                           it->second.spilled_url,
+                           it->second.spilled_node_id,
+                           it->second.pending_creation,
+                           it->second.object_size);
     }
   }
 }
 
 ray::Status OwnershipBasedObjectDirectory::SubscribeObjectLocations(
-    const UniqueID &callback_id, const ObjectID &object_id,
-    const rpc::Address &owner_address, const OnLocationsFound &callback) {
+    const UniqueID &callback_id,
+    const ObjectID &object_id,
+    const rpc::Address &owner_address,
+    const OnLocationsFound &callback) {
   auto it = listeners_.find(object_id);
   if (it == listeners_.end()) {
     // Create an object eviction subscription message.
@@ -281,7 +295,8 @@ ray::Status OwnershipBasedObjectDirectory::SubscribeObjectLocations(
       RAY_CHECK(pub_message.has_worker_object_locations_message());
       const auto &location_info = pub_message.worker_object_locations_message();
       ObjectLocationSubscriptionCallback(
-          location_info, object_id,
+          location_info,
+          object_id,
           /*location_lookup_failed*/ !location_info.ref_removed());
     };
 
@@ -304,7 +319,8 @@ ray::Status OwnershipBasedObjectDirectory::SubscribeObjectLocations(
       // Location lookup can fail if the owner is reachable but no longer has a
       // record of this ObjectRef, most likely due to an issue with the
      // distributed reference counting protocol.
-      ObjectLocationSubscriptionCallback(location_info, object_id,
+      ObjectLocationSubscriptionCallback(location_info,
+                                         object_id,
                                          /*location_lookup_failed*/ true);
     };
 
@@ -312,8 +328,11 @@ ray::Status OwnershipBasedObjectDirectory::SubscribeObjectLocations(
     sub_message->mutable_worker_object_locations_message()->Swap(request.get());
 
     RAY_CHECK(object_location_subscriber_->Subscribe(
-        std::move(sub_message), rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL,
-        owner_address, object_id.Binary(), /*subscribe_done_callback=*/nullptr,
+        std::move(sub_message),
+        rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL,
+        owner_address,
+        object_id.Binary(),
+        /*subscribe_done_callback=*/nullptr,
         /*Success callback=*/msg_published_callback,
         /*Failure callback=*/failure_callback));
 
@@ -346,9 +365,18 @@ ray::Status OwnershipBasedObjectDirectory::SubscribeObjectLocations(
     // structures shared with the caller and potentially invalidating caller
     // iterators. See https://github.com/ray-project/ray/issues/2959.
     io_service_.post(
-        [callback, locations, spilled_url, spilled_node_id, pending_creation, object_size,
+        [callback,
+         locations,
+         spilled_url,
+         spilled_node_id,
+         pending_creation,
+         object_size,
          object_id]() {
-          callback(object_id, locations, spilled_url, spilled_node_id, pending_creation,
+          callback(object_id,
+                   locations,
+                   spilled_url,
+                   spilled_node_id,
+                   pending_creation,
                    object_size);
         },
         "ObjectDirectory.SubscribeObjectLocations");
@@ -365,7 +393,8 @@ ray::Status OwnershipBasedObjectDirectory::UnsubscribeObjectLocations(
   entry->second.callbacks.erase(callback_id);
   if (entry->second.callbacks.empty()) {
     object_location_subscriber_->Unsubscribe(
-        rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL, entry->second.owner_address,
+        rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL,
+        entry->second.owner_address,
         object_id.Binary());
     owner_client_pool_->Disconnect(
         WorkerID::FromBinary(entry->second.owner_address.worker_id()));
@@ -375,7 +404,8 @@
 }
 
 ray::Status OwnershipBasedObjectDirectory::LookupLocations(
-    const ObjectID &object_id, const rpc::Address &owner_address,
+    const ObjectID &object_id,
+    const rpc::Address &owner_address,
     const OnLocationsFound &callback) {
   metrics_num_object_location_lookups_++;
   auto it = listeners_.find(object_id);
@@ -393,9 +423,18 @@ ray::Status OwnershipBasedObjectDirectory::LookupLocations(
     // structures shared with the caller and potentially invalidating caller
     // iterators. See https://github.com/ray-project/ray/issues/2959.
     io_service_.post(
-        [callback, object_id, locations, spilled_url, spilled_node_id, pending_creation,
+        [callback,
+         object_id,
+         locations,
+         spilled_url,
+         spilled_node_id,
+         pending_creation,
          object_size]() {
-          callback(object_id, locations, spilled_url, spilled_node_id, pending_creation,
+          callback(object_id,
+                   locations,
+                   spilled_url,
+                   spilled_node_id,
+                   pending_creation,
                    object_size);
         },
         "ObjectDirectory.LookupLocations");
@@ -410,8 +449,12 @@ ray::Status OwnershipBasedObjectDirectory::LookupLocations(
     // See https://github.com/ray-project/ray/issues/2959.
     io_service_.post(
         [callback, object_id]() {
-          callback(object_id, std::unordered_set<NodeID>(), "", NodeID::Nil(),
-                   /*pending_creation=*/false, 0);
+          callback(object_id,
+                   std::unordered_set<NodeID>(),
+                   "",
+                   NodeID::Nil(),
+                   /*pending_creation=*/false,
+                   0);
         },
         "ObjectDirectory.LookupLocations");
     return Status::OK();
@@ -423,8 +466,9 @@ ray::Status OwnershipBasedObjectDirectory::LookupLocations(
   object_location_request->set_object_id(object_id.Binary());
 
   owner_client->GetObjectLocationsOwner(
-      request, [this, worker_id, object_id, callback](
-                   Status status, const rpc::GetObjectLocationsOwnerReply &reply) {
+      request,
+      [this, worker_id, object_id, callback](
+          Status status, const rpc::GetObjectLocationsOwnerReply &reply) {
        std::unordered_set<NodeID> node_ids;
        std::string spilled_url;
        NodeID spilled_node_id;
@@ -442,8 +486,12 @@ ray::Status OwnershipBasedObjectDirectory::LookupLocations(
                 << ", object already released by distributed reference counting protocol";
          mark_as_failed_(object_id, rpc::ErrorType::OBJECT_DELETED);
        } else {
-          UpdateObjectLocations(reply.object_location_info(), gcs_client_, &node_ids,
-                                &spilled_url, &spilled_node_id, &pending_creation,
+          UpdateObjectLocations(reply.object_location_info(),
+                                gcs_client_,
+                                &node_ids,
+                                &spilled_url,
+                                &spilled_node_id,
+                                &pending_creation,
                                 &object_size);
        }
        RAY_LOG(DEBUG) << "Looked up locations for " << object_id
@@ -455,7 +503,11 @@
        // caller iterators since this is already running in the core worker
        // client's lookup callback stack.
        // See https://github.com/ray-project/ray/issues/2959.
-        callback(object_id, node_ids, spilled_url, spilled_node_id, pending_creation,
+        callback(object_id,
+                 node_ids,
+                 spilled_url,
+                 spilled_node_id,
+                 pending_creation,
                  object_size);
       });
 }
@@ -503,8 +555,10 @@ void OwnershipBasedObjectDirectory::HandleNodeRemoved(const NodeID &node_id) {
     for (const auto &callback_pair : listener.second.callbacks) {
       // It is safe to call the callback directly since this is already running
       // in the subscription callback stack.
-      callback_pair.second(object_id, listener.second.current_object_locations,
-                           listener.second.spilled_url, listener.second.spilled_node_id,
+      callback_pair.second(object_id,
+                           listener.second.current_object_locations,
+                           listener.second.spilled_url,
+                           listener.second.spilled_node_id,
                            listener.second.pending_creation,
                            listener.second.object_size);
     }
diff --git a/src/ray/object_manager/ownership_based_object_directory.h b/src/ray/object_manager/ownership_based_object_directory.h
index b79d44453..f7021e0ff 100644
--- a/src/ray/object_manager/ownership_based_object_directory.h
+++ b/src/ray/object_manager/ownership_based_object_directory.h
@@ -43,9 +43,11 @@ class OwnershipBasedObjectDirectory : public IObjectDirectory {
   /// \param gcs_client A Ray GCS client to request object and node
   /// information from.
   OwnershipBasedObjectDirectory(
-      instrumented_io_context &io_service, std::shared_ptr<gcs::GcsClient> &gcs_client,
+      instrumented_io_context &io_service,
+      std::shared_ptr<gcs::GcsClient> &gcs_client,
       pubsub::SubscriberInterface *object_location_subscriber,
-      rpc::CoreWorkerClientPool *owner_client_pool, int64_t max_object_report_batch_size,
+      rpc::CoreWorkerClientPool *owner_client_pool,
+      int64_t max_object_report_batch_size,
       std::function<void(const ObjectID &, rpc::ErrorType)> mark_as_failed);
 
   virtual ~OwnershipBasedObjectDirectory() {}
@@ -69,12 +71,14 @@ class OwnershipBasedObjectDirectory : public IObjectDirectory {
 
   /// Report to the owner that the given object is added to the current node.
   /// This method guarantees ordering and batches requests.
-  void ReportObjectAdded(const ObjectID &object_id, const NodeID &node_id,
+  void ReportObjectAdded(const ObjectID &object_id,
+                         const NodeID &node_id,
                          const ObjectInfo &object_info) override;
 
   /// Report to the owner that the given object was removed from the current node.
   /// This method guarantees ordering and batches requests.
-  void ReportObjectRemoved(const ObjectID &object_id, const NodeID &node_id,
+  void ReportObjectRemoved(const ObjectID &object_id,
+                           const NodeID &node_id,
                            const ObjectInfo &object_info) override;
 
   void RecordMetrics(uint64_t duration_ms) override;
@@ -139,7 +143,8 @@ class OwnershipBasedObjectDirectory : public IObjectDirectory {
   /// Internal callback function used by object location subscription.
   void ObjectLocationSubscriptionCallback(
       const rpc::WorkerObjectLocationsPubMessage &location_info,
-      const ObjectID &object_id, bool location_lookup_failed);
+      const ObjectID &object_id,
+      bool location_lookup_failed);
 
   /// Send object location update batch from the location_buffers_.
   /// We only allow 1 in-flight request per owner for the batch request
diff --git a/src/ray/object_manager/plasma/client.cc b/src/ray/object_manager/plasma/client.cc
index 3b3c54f5e..5de3fcc20 100644
--- a/src/ray/object_manager/plasma/client.cc
+++ b/src/ray/object_manager/plasma/client.cc
@@ -19,9 +19,9 @@
 
 #include "ray/object_manager/plasma/client.h"
 
-#include
-
 #include
+#include
+#include
 #include
 #include
 #include
@@ -29,8 +29,7 @@
 #include
 #include
-#include
-
+#include "absl/container/flat_hash_map.h"
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/ray_config.h"
 #include "ray/object_manager/plasma/connection.h"
@@ -38,8 +37,6 @@
 #include "ray/object_manager/plasma/protocol.h"
 #include "ray/object_manager/plasma/shared_memory.h"
 
-#include "absl/container/flat_hash_map.h"
-
 namespace fb = plasma::flatbuf;
 
 namespace plasma {
@@ -56,7 +53,8 @@ class PlasmaBuffer : public SharedMemoryBuffer {
  public:
   ~PlasmaBuffer();
 
-  PlasmaBuffer(std::shared_ptr<PlasmaClient::Impl> client, const ObjectID &object_id,
+  PlasmaBuffer(std::shared_ptr<PlasmaClient::Impl> client,
+               const ObjectID &object_id,
               const std::shared_ptr<Buffer> &buffer)
       : SharedMemoryBuffer(buffer, 0, buffer->Size()),
         client_(client),
@@ -72,7 +70,8 @@
 /// be called in the associated Seal call.
 class RAY_NO_EXPORT PlasmaMutableBuffer : public SharedMemoryBuffer {
  public:
-  PlasmaMutableBuffer(std::shared_ptr<PlasmaClient::Impl> client, uint8_t *mutable_data,
+  PlasmaMutableBuffer(std::shared_ptr<PlasmaClient::Impl> client,
+                      uint8_t *mutable_data,
                       int64_t data_size)
       : SharedMemoryBuffer(mutable_data, data_size), client_(client) {}
 
@@ -105,32 +104,46 @@ class PlasmaClient::Impl : public std::enable_shared_from_this<PlasmaClient::Impl
   Status CreateAndSpillIfNeeded(const ObjectID &object_id,
-                                const ray::rpc::Address &owner_address, int64_t data_size,
-                                const uint8_t *metadata, int64_t metadata_size,
-                                std::shared_ptr<Buffer> *data, fb::ObjectSource source,
+                                const ray::rpc::Address &owner_address,
+                                int64_t data_size,
+                                const uint8_t *metadata,
+                                int64_t metadata_size,
+                                std::shared_ptr<Buffer> *data,
+                                fb::ObjectSource source,
                                 int device_num = 0);
 
-  Status RetryCreate(const ObjectID &object_id, uint64_t request_id,
-                     const uint8_t *metadata, uint64_t *retry_with_request_id,
+  Status RetryCreate(const ObjectID &object_id,
+                     uint64_t request_id,
+                     const uint8_t *metadata,
+                     uint64_t *retry_with_request_id,
                      std::shared_ptr<Buffer> *data);
 
   Status TryCreateImmediately(const ObjectID &object_id,
-                              const ray::rpc::Address &owner_address, int64_t data_size,
-                              const uint8_t *metadata, int64_t metadata_size,
-                              std::shared_ptr<Buffer> *data, fb::ObjectSource source,
+                              const ray::rpc::Address &owner_address,
+                              int64_t data_size,
+                              const uint8_t *metadata,
+                              int64_t metadata_size,
+                              std::shared_ptr<Buffer> *data,
+                              fb::ObjectSource source,
                               int device_num);
 
-  Status Get(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
-             std::vector<ObjectBuffer> *object_buffers, bool is_from_worker);
+  Status Get(const std::vector<ObjectID> &object_ids,
+             int64_t timeout_ms,
+             std::vector<ObjectBuffer> *object_buffers,
+             bool is_from_worker);
 
-  Status Get(const ObjectID *object_ids, int64_t num_objects, int64_t timeout_ms,
-             ObjectBuffer *object_buffers, bool is_from_worker);
+  Status Get(const ObjectID *object_ids,
+             int64_t num_objects,
+             int64_t timeout_ms,
+             ObjectBuffer *object_buffers,
+             bool is_from_worker);
 
   Status Release(const ObjectID &object_id);
 
@@ -154,7 +167,8 @@ class PlasmaClient::Impl : public std::enable_shared_from_this<PlasmaClient::Impl
                      std::shared_ptr<Buffer> *data);
 
@@ -173,14 +187,18 @@ class PlasmaClient::Impl : public std::enable_shared_from_this<PlasmaClient::Impl
   Status GetBuffers(
-      const ObjectID *object_ids, int64_t num_objects, int64_t timeout_ms,
+      const ObjectID *object_ids,
+      int64_t num_objects,
+      int64_t timeout_ms,
       const std::function<std::shared_ptr<Buffer>(
           const ObjectID &, const std::shared_ptr<Buffer> &)> &wrap_buffer,
-      ObjectBuffer *object_buffers, bool is_from_worker);
+      ObjectBuffer *object_buffers,
+      bool is_from_worker);
 
   uint8_t *LookupMmappedFile(MEMFD_TYPE store_fd_val);
 
-  void IncrementObjectCount(const ObjectID &object_id, PlasmaObject *object,
+  void IncrementObjectCount(const ObjectID &object_id,
+                            PlasmaObject *object,
                             bool is_sealed);
 
   /// The boost::asio IO context for the client.
@@ -253,7 +271,8 @@ bool PlasmaClient::Impl::IsInUse(const ObjectID &object_id) {
 }
 
 void PlasmaClient::Impl::IncrementObjectCount(const ObjectID &object_id,
-                                              PlasmaObject *object, bool is_sealed) {
+                                              PlasmaObject *object,
+                                              bool is_sealed) {
   // Increment the count of the object to track the fact that it is being used.
   // The corresponding decrement should happen in PlasmaClient::Release.
   auto elem = objects_in_use_.find(object_id);
@@ -288,8 +307,12 @@ Status PlasmaClient::Impl::HandleCreateReply(const ObjectID &object_id,
   int64_t mmap_size;
 
   if (retry_with_request_id) {
-    RAY_RETURN_NOT_OK(ReadCreateReply(buffer.data(), buffer.size(), &id,
-                                      retry_with_request_id, &object, &store_fd,
+    RAY_RETURN_NOT_OK(ReadCreateReply(buffer.data(),
+                                      buffer.size(),
+                                      &id,
+                                      retry_with_request_id,
+                                      &object,
+                                      &store_fd,
                                       &mmap_size));
     if (*retry_with_request_id > 0) {
       // The client should retry the request.
@@ -297,8 +320,8 @@ Status PlasmaClient::Impl::HandleCreateReply(const ObjectID &object_id,
     }
   } else {
     uint64_t unused = 0;
-    RAY_RETURN_NOT_OK(ReadCreateReply(buffer.data(), buffer.size(), &id, &unused, &object,
-                                      &store_fd, &mmap_size));
+    RAY_RETURN_NOT_OK(ReadCreateReply(
+        buffer.data(), buffer.size(), &id, &unused, &object, &store_fd, &mmap_size));
     RAY_CHECK(unused == 0);
   }
 
@@ -308,7 +331,8 @@ Status PlasmaClient::Impl::HandleCreateReply(const ObjectID &object_id,
     // The metadata should come right after the data.
     RAY_CHECK(object.metadata_offset == object.data_offset + object.data_size);
     *data = std::make_shared<PlasmaMutableBuffer>(
-        shared_from_this(), GetStoreFdAndMmap(store_fd, mmap_size) + object.data_offset,
+        shared_from_this(),
+        GetStoreFdAndMmap(store_fd, mmap_size) + object.data_offset,
         object.data_size);
     // If plasma_create is being called from a transfer, then we will not copy the
     // metadata here. The metadata will be written along with the data streamed
@@ -333,17 +357,26 @@ Status PlasmaClient::Impl::HandleCreateReply(const ObjectID &object_id,
   return Status::OK();
 }
 
-Status PlasmaClient::Impl::CreateAndSpillIfNeeded(
-    const ObjectID &object_id, const ray::rpc::Address &owner_address, int64_t data_size,
-    const uint8_t *metadata, int64_t metadata_size, std::shared_ptr<Buffer> *data,
-    fb::ObjectSource source, int device_num) {
+Status PlasmaClient::Impl::CreateAndSpillIfNeeded(const ObjectID &object_id,
+                                                  const ray::rpc::Address &owner_address,
+                                                  int64_t data_size,
+                                                  const uint8_t *metadata,
+                                                  int64_t metadata_size,
+                                                  std::shared_ptr<Buffer> *data,
+                                                  fb::ObjectSource source,
+                                                  int device_num) {
   std::unique_lock<std::recursive_mutex> guard(client_mutex_);
   uint64_t retry_with_request_id = 0;
 
   RAY_LOG(DEBUG) << "called plasma_create on conn " << store_conn_ << " with size "
                  << data_size << " and metadata size " << metadata_size;
-  RAY_RETURN_NOT_OK(SendCreateRequest(store_conn_, object_id, owner_address, data_size,
-                                      metadata_size, source, device_num,
+  RAY_RETURN_NOT_OK(SendCreateRequest(store_conn_,
+                                      object_id,
+                                      owner_address,
+                                      data_size,
+                                      metadata_size,
+                                      source,
+                                      device_num,
                                       /*try_immediately=*/false));
   Status status = HandleCreateReply(object_id, metadata, &retry_with_request_id, data);
 
@@ -355,14 +388,15 @@ Status PlasmaClient::Impl::CreateAndSpillIfNeeded(
     guard.lock();
     RAY_LOG(DEBUG) << "Retrying request for object " << object_id << " with request ID "
                    << retry_with_request_id;
-    status = RetryCreate(object_id, retry_with_request_id, metadata,
-                         &retry_with_request_id, data);
+    status = RetryCreate(
+        object_id, retry_with_request_id, metadata, &retry_with_request_id, data);
   }
 
   return status;
 }
 
-Status PlasmaClient::Impl::RetryCreate(const ObjectID &object_id, uint64_t request_id,
+Status PlasmaClient::Impl::RetryCreate(const ObjectID &object_id,
+                                       uint64_t request_id,
                                        const uint8_t *metadata,
                                        uint64_t *retry_with_request_id,
                                        std::shared_ptr<Buffer> *data) {
@@ -371,25 +405,37 @@ Status PlasmaClient::Impl::RetryCreate(const ObjectID &object_id, uint64_t reque
   return HandleCreateReply(object_id, metadata, retry_with_request_id, data);
 }
 
-Status PlasmaClient::Impl::TryCreateImmediately(
-    const ObjectID &object_id, const ray::rpc::Address &owner_address, int64_t data_size,
-    const uint8_t *metadata, int64_t metadata_size, std::shared_ptr<Buffer> *data,
-    fb::ObjectSource source, int device_num) {
+Status PlasmaClient::Impl::TryCreateImmediately(const ObjectID &object_id,
+                                                const ray::rpc::Address &owner_address,
+                                                int64_t data_size,
+                                                const uint8_t *metadata,
+                                                int64_t metadata_size,
+                                                std::shared_ptr<Buffer> *data,
+                                                fb::ObjectSource source,
+                                                int device_num) {
   std::lock_guard<std::recursive_mutex> guard(client_mutex_);
 
   RAY_LOG(DEBUG) << "called plasma_create on conn " << store_conn_ << " with size "
                  << data_size << " and metadata size " << metadata_size;
-  RAY_RETURN_NOT_OK(SendCreateRequest(store_conn_, object_id, owner_address, data_size,
-                                      metadata_size, source, device_num,
+  RAY_RETURN_NOT_OK(SendCreateRequest(store_conn_,
+                                      object_id,
+                                      owner_address,
+                                      data_size,
+                                      metadata_size,
+                                      source,
+                                      device_num,
                                       /*try_immediately=*/true));
   return HandleCreateReply(object_id, metadata, nullptr, data);
 }
 
 Status PlasmaClient::Impl::GetBuffers(
-    const ObjectID *object_ids, int64_t num_objects, int64_t timeout_ms,
+    const ObjectID *object_ids,
+    int64_t num_objects,
+    int64_t timeout_ms,
     const std::function<std::shared_ptr<Buffer>(
         const ObjectID &, const std::shared_ptr<Buffer> &)> &wrap_buffer,
-    ObjectBuffer *object_buffers, bool is_from_worker) {
+    ObjectBuffer *object_buffers,
+    bool is_from_worker) {
   // Fill out the info for the objects that are already in use locally.
   bool all_present = true;
   for (int64_t i = 0; i < num_objects; ++i) {
@@ -436,8 +482,8 @@ Status PlasmaClient::Impl::GetBuffers(
 
   // If we get here, then the objects aren't all currently in use by this
   // client, so we need to send a request to the plasma store.
-  RAY_RETURN_NOT_OK(SendGetRequest(store_conn_, &object_ids[0], num_objects, timeout_ms,
-                                   is_from_worker));
+  RAY_RETURN_NOT_OK(SendGetRequest(
+      store_conn_, &object_ids[0], num_objects, timeout_ms, is_from_worker));
   std::vector<uint8_t> buffer;
   RAY_RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaGetReply, &buffer));
   std::vector<ObjectID> received_object_ids(num_objects);
   std::vector<PlasmaObject> object_data(num_objects);
   PlasmaObject *object;
   std::vector<MEMFD_TYPE> store_fds;
   std::vector<int64_t> mmap_sizes;
-  RAY_RETURN_NOT_OK(ReadGetReply(buffer.data(), buffer.size(), received_object_ids.data(),
-                                 object_data.data(), num_objects, store_fds, mmap_sizes));
+  RAY_RETURN_NOT_OK(ReadGetReply(buffer.data(),
+                                 buffer.size(),
+                                 received_object_ids.data(),
+                                 object_data.data(),
+                                 num_objects,
+                                 store_fds,
+                                 mmap_sizes));
 
   // We mmap all of the file descriptors here so that we can avoid looking them up
   // in the subsequent loop based on just the store file descriptor and without
@@ -498,7 +549,8 @@ Status PlasmaClient::Impl::GetBuffers(
 }
 
 Status PlasmaClient::Impl::Get(const std::vector<ObjectID> &object_ids,
-                               int64_t timeout_ms, std::vector<ObjectBuffer> *out,
+                               int64_t timeout_ms,
+                               std::vector<ObjectBuffer> *out,
                                bool is_from_worker) {
   std::lock_guard<std::recursive_mutex> guard(client_mutex_);
 
@@ -508,8 +560,8 @@ Status PlasmaClient::Impl::Get(const std::vector<ObjectID> &object_ids,
   };
   const size_t num_objects = object_ids.size();
   *out = std::vector<ObjectBuffer>(num_objects);
-  return GetBuffers(&object_ids[0], num_objects, timeout_ms, wrap_buffer, &(*out)[0],
-                    is_from_worker);
+  return GetBuffers(
+      &object_ids[0], num_objects, timeout_ms, wrap_buffer, &(*out)[0], is_from_worker);
 }
 
 Status PlasmaClient::Impl::MarkObjectUnused(const ObjectID &object_id) {
@@ -666,7 +718,8 @@ Status PlasmaClient::Impl::Evict(int64_t num_bytes, int64_t &num_bytes_evicted)
 
 Status PlasmaClient::Impl::Connect(const std::string &store_socket_name,
                                    const std::string &manager_socket_name,
-                                   int release_delay, int num_retries) {
+                                   int release_delay,
+                                   int num_retries) {
   std::lock_guard<std::recursive_mutex> guard(client_mutex_);
 
   /// The local stream socket that connects to store.
@@ -718,34 +771,53 @@ PlasmaClient::PlasmaClient() : impl_(std::make_shared<PlasmaClient::Impl>()) {}
 
 PlasmaClient::~PlasmaClient() {}
 
 Status PlasmaClient::Connect(const std::string &store_socket_name,
-                             const std::string &manager_socket_name, int release_delay,
+                             const std::string &manager_socket_name,
+                             int release_delay,
                              int num_retries) {
-  return impl_->Connect(store_socket_name, manager_socket_name, release_delay,
-                        num_retries);
+  return impl_->Connect(
+      store_socket_name, manager_socket_name, release_delay, num_retries);
 }
 
 Status PlasmaClient::CreateAndSpillIfNeeded(const ObjectID &object_id,
                                             const ray::rpc::Address &owner_address,
-                                            int64_t data_size, const uint8_t *metadata,
+                                            int64_t data_size,
+                                            const uint8_t *metadata,
                                             int64_t metadata_size,
                                             std::shared_ptr<Buffer> *data,
-                                            fb::ObjectSource source, int device_num) {
-  return impl_->CreateAndSpillIfNeeded(object_id, owner_address, data_size, metadata,
-                                       metadata_size, data, source, device_num);
+                                            fb::ObjectSource source,
+                                            int device_num) {
+  return impl_->CreateAndSpillIfNeeded(object_id,
+                                       owner_address,
+                                       data_size,
+                                       metadata,
+                                       metadata_size,
+                                       data,
+                                       source,
+                                       device_num);
 }
 
 Status PlasmaClient::TryCreateImmediately(const ObjectID &object_id,
                                           const ray::rpc::Address &owner_address,
-                                          int64_t data_size, const uint8_t *metadata,
+                                          int64_t data_size,
+                                          const uint8_t *metadata,
                                           int64_t metadata_size,
                                           std::shared_ptr<Buffer> *data,
-                                          fb::ObjectSource source, int device_num) {
-  return impl_->TryCreateImmediately(object_id, owner_address, data_size, metadata,
-                                     metadata_size, data, source, device_num);
+                                          fb::ObjectSource source,
+                                          int device_num) {
+  return impl_->TryCreateImmediately(object_id,
+                                     owner_address,
+                                     data_size,
+                                     metadata,
+                                     metadata_size,
+                                     data,
+                                     source,
+                                     device_num);
 }
 
-Status PlasmaClient::Get(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
-                         std::vector<ObjectBuffer> *object_buffers, bool is_from_worker) {
+Status PlasmaClient::Get(const std::vector<ObjectID> &object_ids,
+                         int64_t timeout_ms,
+                         std::vector<ObjectBuffer> *object_buffers,
+                         bool is_from_worker) {
   return impl_->Get(object_ids, timeout_ms, object_buffers, is_from_worker);
 }
diff --git a/src/ray/object_manager/plasma/client.h b/src/ray/object_manager/plasma/client.h
index c15ecf75c..d466528ec 100644
--- a/src/ray/object_manager/plasma/client.h
+++ b/src/ray/object_manager/plasma/client.h
@@ -77,8 +77,10 @@ class PlasmaClientInterface {
   /// \param[out] object_buffers The object results.
   /// \param is_from_worker Whether or not the Get request comes from a Ray worker.
   /// \return The return status.
-  virtual Status Get(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
-                     std::vector<ObjectBuffer> *object_buffers, bool is_from_worker) = 0;
+  virtual Status Get(const std::vector<ObjectID> &object_ids,
+                     int64_t timeout_ms,
+                     std::vector<ObjectBuffer> *object_buffers,
+                     bool is_from_worker) = 0;
 
   /// Seal an object in the object store. The object will be immutable after
   /// this
@@ -125,7 +127,8 @@ class PlasmaClientInterface {
   /// be either sealed or aborted.
   virtual Status CreateAndSpillIfNeeded(const ObjectID &object_id,
                                         const ray::rpc::Address &owner_address,
-                                        int64_t data_size, const uint8_t *metadata,
+                                        int64_t data_size,
+                                        const uint8_t *metadata,
                                         int64_t metadata_size,
                                         std::shared_ptr<Buffer> *data,
                                         plasma::flatbuf::ObjectSource source,
@@ -158,7 +161,8 @@ class PlasmaClient : public PlasmaClientInterface {
   /// \param num_retries number of attempts to connect to IPC socket, default 50
   /// \return The return status.
Status Connect(const std::string &store_socket_name, - const std::string &manager_socket_name = "", int release_delay = 0, + const std::string &manager_socket_name = "", + int release_delay = 0, int num_retries = -1); /// Create an object in the Plasma Store. Any metadata for this object must be @@ -188,10 +192,13 @@ class PlasmaClient : public PlasmaClientInterface { /// The returned object must be released once the caller is done with it. It must also /// be either sealed or aborted. Status CreateAndSpillIfNeeded(const ObjectID &object_id, - const ray::rpc::Address &owner_address, int64_t data_size, - const uint8_t *metadata, int64_t metadata_size, + const ray::rpc::Address &owner_address, + int64_t data_size, + const uint8_t *metadata, + int64_t metadata_size, std::shared_ptr<Buffer> *data, - plasma::flatbuf::ObjectSource source, int device_num = 0); + plasma::flatbuf::ObjectSource source, + int device_num = 0); /// Create an object in the Plasma Store. Any metadata for this object must /// be passed in when the object is created. @@ -220,10 +227,13 @@ class PlasmaClient : public PlasmaClientInterface { /// The returned object must be released once the caller is done with it. It must also /// be either sealed or aborted. Status TryCreateImmediately(const ObjectID &object_id, - const ray::rpc::Address &owner_address, int64_t data_size, - const uint8_t *metadata, int64_t metadata_size, + const ray::rpc::Address &owner_address, + int64_t data_size, + const uint8_t *metadata, + int64_t metadata_size, std::shared_ptr<Buffer> *data, - plasma::flatbuf::ObjectSource source, int device_num = 0); + plasma::flatbuf::ObjectSource source, + int device_num = 0); /// Get some objects from the Plasma Store. This function will block until the /// objects have all been created and sealed in the Plasma Store or the @@ -240,8 +250,10 @@ class PlasmaClient : public PlasmaClientInterface { /// \param[out] object_buffers The object results. /// \param is_from_worker Whether or not the Get request comes from a Ray worker. /// \return The return status. - Status Get(const std::vector<ObjectID> &object_ids, int64_t timeout_ms, - std::vector<ObjectBuffer> *object_buffers, bool is_from_worker); + Status Get(const std::vector<ObjectID> &object_ids, + int64_t timeout_ms, + std::vector<ObjectBuffer> *object_buffers, + bool is_from_worker); /// Tell Plasma that the client no longer needs the object. This should be /// called after Get() or Create() when the client is done with the object. @@ -334,8 +346,10 @@ class PlasmaClient : public PlasmaClientInterface { /// \param retry_with_request_id If the request is not yet fulfilled, this /// will be set to a unique ID with which the client should retry. /// \param data The address of the newly created object will be written here.
- Status RetryCreate(const ObjectID &object_id, uint64_t request_id, - const uint8_t *metadata, uint64_t *retry_with_request_id, + Status RetryCreate(const ObjectID &object_id, + uint64_t request_id, + const uint8_t *metadata, + uint64_t *retry_with_request_id, std::shared_ptr *data); friend class PlasmaBuffer; diff --git a/src/ray/object_manager/plasma/common.h b/src/ray/object_manager/plasma/common.h index 6297a2f60..de77dd592 100644 --- a/src/ray/object_manager/plasma/common.h +++ b/src/ray/object_manager/plasma/common.h @@ -68,7 +68,11 @@ struct Allocation { private: // Only created by Allocator - Allocation(void *address, int64_t size, MEMFD_TYPE fd, ptrdiff_t offset, int device_num, + Allocation(void *address, + int64_t size, + MEMFD_TYPE fd, + ptrdiff_t offset, + int device_num, int64_t mmap_size) : address(address), size(size), diff --git a/src/ray/object_manager/plasma/compat.h b/src/ray/object_manager/plasma/compat.h index 835ce75b4..b367a1809 100644 --- a/src/ray/object_manager/plasma/compat.h +++ b/src/ray/object_manager/plasma/compat.h @@ -27,6 +27,7 @@ #include /* __darwin_mach_port_t */ typedef __darwin_mach_port_t mach_port_t; #include + #include mach_port_t pthread_mach_thread_np(pthread_t); #endif /* _MACH_PORT_T */ diff --git a/src/ray/object_manager/plasma/connection.cc b/src/ray/object_manager/plasma/connection.cc index c0884184a..376c7a5bb 100644 --- a/src/ray/object_manager/plasma/connection.cc +++ b/src/ray/object_manager/plasma/connection.cc @@ -24,7 +24,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr &sto namespace { const std::vector GenerateEnumNames(const char *const *enum_names_ptr, - int start_index, int end_index) { + int start_index, + int end_index) { std::vector enum_names; for (int i = 0; i < start_index; ++i) { enum_names.push_back("EmptyMessageType"); @@ -44,12 +45,15 @@ const std::vector GenerateEnumNames(const char *const *enum_names_p } static const std::vector object_store_message_enum = - GenerateEnumNames(flatbuf::EnumNamesMessageType(), static_cast(MessageType::MIN), + GenerateEnumNames(flatbuf::EnumNamesMessageType(), + static_cast(MessageType::MIN), static_cast(MessageType::MAX)); } // namespace Client::Client(ray::MessageHandler &message_handler, ray::local_stream_socket &&socket) - : ray::ClientConnection(message_handler, std::move(socket), "worker", + : ray::ClientConnection(message_handler, + std::move(socket), + "worker", object_store_message_enum, static_cast(MessageType::PlasmaDisconnectClient)) {} @@ -57,10 +61,12 @@ std::shared_ptr Client::Create(PlasmaStoreMessageHandler message_handler ray::local_stream_socket &&socket) { ray::MessageHandler ray_message_handler = [message_handler](std::shared_ptr client, - int64_t message_type, const std::vector &message) { + int64_t message_type, + const std::vector &message) { Status s = message_handler( std::static_pointer_cast(client->shared_ClientConnection_from_this()), - (MessageType)message_type, message); + (MessageType)message_type, + message); if (!s.ok()) { if (!s.IsDisconnected()) { RAY_LOG(ERROR) << "Fail to process client message. 
" << s.ToString(); @@ -92,8 +98,13 @@ Status Client::SendFd(MEMFD_TYPE fd) { return Status::Invalid("Cannot open PID = " + std::to_string(target_pid)); } HANDLE target_handle = NULL; - bool success = DuplicateHandle(GetCurrentProcess(), fd.first, target_process, - &target_handle, 0, TRUE, DUPLICATE_SAME_ACCESS); + bool success = DuplicateHandle(GetCurrentProcess(), + fd.first, + target_process, + &target_handle, + 0, + TRUE, + DUPLICATE_SAME_ACCESS); if (!success) { // TODO(suquark): Define better error type. return Status::IOError("Fail to duplicate handle to PID = " + @@ -103,8 +114,13 @@ Status Client::SendFd(MEMFD_TYPE fd) { if (!s.ok()) { /* we failed to send the handle, and it needs cleaning up! */ HANDLE duplicated_back = NULL; - if (DuplicateHandle(target_process, fd.first, GetCurrentProcess(), &duplicated_back, - 0, FALSE, DUPLICATE_CLOSE_SOURCE)) { + if (DuplicateHandle(target_process, + fd.first, + GetCurrentProcess(), + &duplicated_back, + 0, + FALSE, + DUPLICATE_CLOSE_SOURCE)) { CloseHandle(duplicated_back); } CloseHandle(target_process); diff --git a/src/ray/object_manager/plasma/connection.h b/src/ray/object_manager/plasma/connection.h index 1302f14ba..02a1736e8 100644 --- a/src/ray/object_manager/plasma/connection.h +++ b/src/ray/object_manager/plasma/connection.h @@ -1,12 +1,11 @@ #pragma once +#include "absl/container/flat_hash_set.h" #include "ray/common/client_connection.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/object_manager/plasma/compat.h" -#include "absl/container/flat_hash_set.h" - namespace plasma { namespace flatbuf { diff --git a/src/ray/object_manager/plasma/create_request_queue.cc b/src/ray/object_manager/plasma/create_request_queue.cc index 4f7fc84c3..12d3d92c6 100644 --- a/src/ray/object_manager/plasma/create_request_queue.cc +++ b/src/ray/object_manager/plasma/create_request_queue.cc @@ -36,7 +36,8 @@ uint64_t CreateRequestQueue::AddRequest(const ObjectID &object_id, return req_id; } -bool CreateRequestQueue::GetRequestResult(uint64_t req_id, PlasmaObject *result, +bool CreateRequestQueue::GetRequestResult(uint64_t req_id, + PlasmaObject *result, PlasmaError *error) { auto it = fulfilled_requests_.find(req_id); if (it == fulfilled_requests_.end()) { @@ -59,12 +60,15 @@ bool CreateRequestQueue::GetRequestResult(uint64_t req_id, PlasmaObject *result, } std::pair CreateRequestQueue::TryRequestImmediately( - const ObjectID &object_id, const std::shared_ptr &client, - const CreateObjectCallback &create_callback, size_t object_size) { + const ObjectID &object_id, + const std::shared_ptr &client, + const CreateObjectCallback &create_callback, + size_t object_size) { PlasmaObject result = {}; // Immediately fulfill it using the fallback allocator. - PlasmaError error = create_callback(/*fallback_allocator=*/true, &result, + PlasmaError error = create_callback(/*fallback_allocator=*/true, + &result, /*spilling_required=*/nullptr); return {result, error}; } @@ -119,7 +123,8 @@ Status CreateRequestQueue::ProcessRequests() { return Status::ObjectStoreFull("Waiting for grace period."); } else { // Trigger the fallback allocator. 
- status = ProcessRequest(/*fallback_allocator=*/true, *request_it, + status = ProcessRequest(/*fallback_allocator=*/true, + *request_it, /*spilling_required=*/nullptr); if (!status.ok()) { std::string dump = ""; diff --git a/src/ray/object_manager/plasma/create_request_queue.h b/src/ray/object_manager/plasma/create_request_queue.h index c1dc6e316..740ef7eed 100644 --- a/src/ray/object_manager/plasma/create_request_queue.h +++ b/src/ray/object_manager/plasma/create_request_queue.h @@ -19,7 +19,6 @@ #include #include "absl/container/flat_hash_map.h" - #include "ray/common/status.h" #include "ray/object_manager/common.h" #include "ray/object_manager/plasma/common.h" @@ -91,8 +90,10 @@ class CreateRequestQueue { /// if there are other requests queued or there is not enough space left in /// the object store, this will return an out-of-memory error. std::pair TryRequestImmediately( - const ObjectID &object_id, const std::shared_ptr &client, - const CreateObjectCallback &create_callback, size_t object_size); + const ObjectID &object_id, + const std::shared_ptr &client, + const CreateObjectCallback &create_callback, + size_t object_size); /// Process requests in the queue. /// @@ -115,9 +116,11 @@ class CreateRequestQueue { private: struct CreateRequest { - CreateRequest(const ObjectID &object_id, uint64_t request_id, + CreateRequest(const ObjectID &object_id, + uint64_t request_id, const std::shared_ptr &client, - CreateObjectCallback create_callback, size_t object_size) + CreateObjectCallback create_callback, + size_t object_size) : object_id(object_id), request_id(request_id), client(client), @@ -149,7 +152,8 @@ class CreateRequestQueue { /// Process a single request. Sets the request's error result to the error /// returned by the request handler inside. Returns OK if the request can be /// finished. - Status ProcessRequest(bool fallback_allocator, std::unique_ptr &request, + Status ProcessRequest(bool fallback_allocator, + std::unique_ptr &request, bool *spilling_required); /// Finish a queued request and remove it from the queue. diff --git a/src/ray/object_manager/plasma/dlmalloc.cc b/src/ray/object_manager/plasma/dlmalloc.cc index c90857ca0..9d95388ce 100644 --- a/src/ray/object_manager/plasma/dlmalloc.cc +++ b/src/ray/object_manager/plasma/dlmalloc.cc @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -#include "ray/object_manager/plasma/malloc.h" - #include + #include +#include "ray/object_manager/plasma/malloc.h" + #ifdef __linux__ #ifndef _GNU_SOURCE #define _GNU_SOURCE /* Turns on fallocate() definition */ @@ -119,9 +120,12 @@ DLMallocConfig dlmalloc_config; #ifdef _WIN32 void create_and_mmap_buffer(int64_t size, void **pointer, HANDLE *handle) { - *handle = CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, + *handle = CreateFileMapping(INVALID_HANDLE_VALUE, + NULL, + PAGE_READWRITE, (DWORD)((uint64_t)size >> (CHAR_BIT * sizeof(DWORD))), - (DWORD)(uint64_t)size, NULL); + (DWORD)(uint64_t)size, + NULL); RAY_CHECK(*handle != nullptr) << "CreateFileMapping() failed. 
GetLastError() = " << GetLastError(); *pointer = MapViewOfFile(*handle, FILE_MAP_ALL_ACCESS, 0, 0, (size_t)size); @@ -297,7 +301,8 @@ bool IsOutsideInitialAllocation(void *p) { } void SetDLMallocConfig(const std::string &plasma_directory, - const std::string &fallback_directory, bool hugepage_enabled, + const std::string &fallback_directory, + bool hugepage_enabled, bool fallback_enabled) { dlmalloc_config.hugepages_enabled = hugepage_enabled; dlmalloc_config.directory = plasma_directory; diff --git a/src/ray/object_manager/plasma/eviction_policy.cc b/src/ray/object_manager/plasma/eviction_policy.cc index 77e283e12..5e91e53a7 100644 --- a/src/ray/object_manager/plasma/eviction_policy.cc +++ b/src/ray/object_manager/plasma/eviction_policy.cc @@ -16,11 +16,12 @@ // under the License. #include "ray/object_manager/plasma/eviction_policy.h" -#include "ray/object_manager/plasma/plasma_allocator.h" #include #include +#include "ray/object_manager/plasma/plasma_allocator.h" + namespace plasma { void LRUCache::Add(const ObjectID &key, int64_t size) { diff --git a/src/ray/object_manager/plasma/fling.cc b/src/ray/object_manager/plasma/fling.cc index 39b7cffeb..358311713 100644 --- a/src/ray/object_manager/plasma/fling.cc +++ b/src/ray/object_manager/plasma/fling.cc @@ -20,14 +20,13 @@ #include #include - -#include "ray/util/logging.h" - #include #include #include #include +#include "ray/util/logging.h" + // This is necessary for Mac OS X, see http://www.apuebook.com/faqs2e.html // (10). #if !defined(CMSG_SPACE) && !defined(CMSG_LEN) diff --git a/src/ray/object_manager/plasma/get_request_queue.cc b/src/ray/object_manager/plasma/get_request_queue.cc index 015710ce7..6f6fb713c 100644 --- a/src/ray/object_manager/plasma/get_request_queue.cc +++ b/src/ray/object_manager/plasma/get_request_queue.cc @@ -18,7 +18,8 @@ namespace plasma { GetRequest::GetRequest(instrumented_io_context &io_context, const std::shared_ptr &client, - const std::vector &object_ids, bool is_from_worker, + const std::vector &object_ids, + bool is_from_worker, int64_t num_unique_objects_to_wait_for) : client(client), object_ids(object_ids.begin(), object_ids.end()), @@ -51,11 +52,12 @@ bool GetRequest::IsRemoved() const { return is_removed_; } void GetRequestQueue::AddRequest(const std::shared_ptr &client, const std::vector &object_ids, - int64_t timeout_ms, bool is_from_worker) { + int64_t timeout_ms, + bool is_from_worker) { const absl::flat_hash_set unique_ids(object_ids.begin(), object_ids.end()); // Create a get request for this object. - auto get_request = std::make_shared(io_context_, client, object_ids, - is_from_worker, unique_ids.size()); + auto get_request = std::make_shared( + io_context_, client, object_ids, is_from_worker, unique_ids.size()); for (const auto &object_id : unique_ids) { // Check if this object is already present // locally. If so, record that the object is being used and mark it as accounted for. diff --git a/src/ray/object_manager/plasma/get_request_queue.h b/src/ray/object_manager/plasma/get_request_queue.h index 5e014499d..c6b9653df 100644 --- a/src/ray/object_manager/plasma/get_request_queue.h +++ b/src/ray/object_manager/plasma/get_request_queue.h @@ -29,7 +29,8 @@ using AllObjectReadyCallback = struct GetRequest { GetRequest(instrumented_io_context &io_context, const std::shared_ptr &client, - const std::vector &object_ids, bool is_from_worker, + const std::vector &object_ids, + bool is_from_worker, int64_t num_unique_objects_to_wait_for); /// The client that called get. 
std::shared_ptr client; @@ -89,7 +90,8 @@ class GetRequestQueue { /// satisfied. \param all_objects_callback the callback function called when all objects /// have been satisfied. void AddRequest(const std::shared_ptr &client, - const std::vector<ObjectID> &object_ids, int64_t timeout_ms, + const std::vector<ObjectID> &object_ids, + int64_t timeout_ms, bool is_from_worker); /// Remove all of the GetRequests for a given client. diff --git a/src/ray/object_manager/plasma/malloc.cc b/src/ray/object_manager/plasma/malloc.cc index 77f52bcca..ce7b39caf 100644 --- a/src/ray/object_manager/plasma/malloc.cc +++ b/src/ray/object_manager/plasma/malloc.cc @@ -33,7 +33,9 @@ static ptrdiff_t pointer_distance(void const *pfrom, void const *pto) { return (unsigned char const *)pto - (unsigned char const *)pfrom; } -bool GetMallocMapinfo(const void *const addr, MEMFD_TYPE *fd, int64_t *map_size, +bool GetMallocMapinfo(const void *const addr, + MEMFD_TYPE *fd, + int64_t *map_size, ptrdiff_t *offset) { // TODO(rshin): Implement a more efficient search through mmap_records. for (const auto &entry : mmap_records) { diff --git a/src/ray/object_manager/plasma/malloc.h b/src/ray/object_manager/plasma/malloc.h index 06b181e4e..0ad41f1c0 100644 --- a/src/ray/object_manager/plasma/malloc.h +++ b/src/ray/object_manager/plasma/malloc.h @@ -21,6 +21,7 @@ #include #include + #include "ray/object_manager/plasma/compat.h" namespace plasma { @@ -43,7 +44,9 @@ extern std::unordered_map mmap_records; /// private function, only used by PlasmaAllocator namespace internal { -bool GetMallocMapinfo(const void *const addr, MEMFD_TYPE *fd, int64_t *map_length, +bool GetMallocMapinfo(const void *const addr, + MEMFD_TYPE *fd, + int64_t *map_length, ptrdiff_t *offset); } } // namespace plasma diff --git a/src/ray/object_manager/plasma/object_lifecycle_manager.cc b/src/ray/object_manager/plasma/object_lifecycle_manager.cc index fee8b486e..365a1be6b 100644 --- a/src/ray/object_manager/plasma/object_lifecycle_manager.cc +++ b/src/ray/object_manager/plasma/object_lifecycle_manager.cc @@ -32,7 +32,8 @@ ObjectLifecycleManager::ObjectLifecycleManager( stats_collector_() {} std::pair ObjectLifecycleManager::CreateObject( - const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source, + const ray::ObjectInfo &object_info, + plasma::flatbuf::ObjectSource source, bool fallback_allocator) { RAY_LOG(DEBUG) << "attempting to create object " << object_info.object_id << " size " << object_info.data_size; @@ -171,7 +172,8 @@ std::string ObjectLifecycleManager::EvictionPolicyDebugString() const { } const LocalObject *ObjectLifecycleManager::CreateObjectInternal( - const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source, + const ray::ObjectInfo &object_info, + plasma::flatbuf::ObjectSource source, bool allow_fallback_allocation) { // Try to evict objects until there is enough space. // NOTE(ekl) if we can't achieve this after a number of retries, it's @@ -277,7 +279,8 @@ void ObjectLifecycleManager::GetDebugDump(std::stringstream &buffer) const { // For test only.
ObjectLifecycleManager::ObjectLifecycleManager( - std::unique_ptr store, std::unique_ptr eviction_policy, + std::unique_ptr store, + std::unique_ptr eviction_policy, ray::DeleteObjectCallback delete_object_callback) : object_store_(std::move(store)), eviction_policy_(std::move(eviction_policy)), diff --git a/src/ray/object_manager/plasma/object_lifecycle_manager.h b/src/ray/object_manager/plasma/object_lifecycle_manager.h index e1a007d16..1e2a7d34f 100644 --- a/src/ray/object_manager/plasma/object_lifecycle_manager.h +++ b/src/ray/object_manager/plasma/object_lifecycle_manager.h @@ -42,7 +42,8 @@ class IObjectLifecycleManager { /// - nullptr and error message, including ObjectExists/OutOfMemory /// TODO(scv119): use RAII instead of pointer for returned object. virtual std::pair CreateObject( - const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source, + const ray::ObjectInfo &object_info, + plasma::flatbuf::ObjectSource source, bool fallback_allocator) = 0; /// Get object by id. @@ -103,7 +104,8 @@ class ObjectLifecycleManager : public IObjectLifecycleManager { ray::DeleteObjectCallback delete_object_callback); std::pair CreateObject( - const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source, + const ray::ObjectInfo &object_info, + plasma::flatbuf::ObjectSource source, bool fallback_allocator) override; const LocalObject *GetObject(const ObjectID &object_id) const override; diff --git a/src/ray/object_manager/plasma/plasma_allocator.cc b/src/ray/object_manager/plasma/plasma_allocator.cc index 7b7eaa3e8..015ef1717 100644 --- a/src/ray/object_manager/plasma/plasma_allocator.cc +++ b/src/ray/object_manager/plasma/plasma_allocator.cc @@ -15,18 +15,19 @@ // specific language governing permissions and limitations // under the License. 
-#include "ray/common/ray_config.h" -#include "ray/util/logging.h" - -#include "ray/object_manager/plasma/malloc.h" #include "ray/object_manager/plasma/plasma_allocator.h" +#include "ray/common/ray_config.h" +#include "ray/object_manager/plasma/malloc.h" +#include "ray/util/logging.h" + namespace plasma { namespace internal { bool IsOutsideInitialAllocation(void *ptr); void SetDLMallocConfig(const std::string &plasma_directory, - const std::string &fallback_directory, bool hugepage_enabled, + const std::string &fallback_directory, + bool hugepage_enabled, bool fallback_enabled); } // namespace internal @@ -59,12 +60,15 @@ const int64_t kDlMallocReserved = 256 * sizeof(size_t); PlasmaAllocator::PlasmaAllocator(const std::string &plasma_directory, const std::string &fallback_directory, - bool hugepage_enabled, int64_t footprint_limit) + bool hugepage_enabled, + int64_t footprint_limit) : kFootprintLimit(footprint_limit), kAlignment(kAllocationAlignment), allocated_(0), fallback_allocated_(0) { - internal::SetDLMallocConfig(plasma_directory, fallback_directory, hugepage_enabled, + internal::SetDLMallocConfig(plasma_directory, + fallback_directory, + hugepage_enabled, /*fallback_enabled=*/true); RAY_CHECK(kFootprintLimit > kDlMallocReserved) << "Footprint limit has to be greater than " << kDlMallocReserved; @@ -134,8 +138,12 @@ absl::optional PlasmaAllocator::BuildAllocation(void *addr, size_t s ptrdiff_t offset; if (internal::GetMallocMapinfo(addr, &fd, &mmap_size, &offset)) { - return Allocation(addr, static_cast(size), std::move(fd), offset, - 0 /* device_number*/, mmap_size); + return Allocation(addr, + static_cast(size), + std::move(fd), + offset, + 0 /* device_number*/, + mmap_size); } return absl::nullopt; } diff --git a/src/ray/object_manager/plasma/plasma_allocator.h b/src/ray/object_manager/plasma/plasma_allocator.h index 3cf982c81..880d901dc 100644 --- a/src/ray/object_manager/plasma/plasma_allocator.h +++ b/src/ray/object_manager/plasma/plasma_allocator.h @@ -20,9 +20,9 @@ #include #include #include -#include "ray/object_manager/plasma/allocator.h" #include "absl/types/optional.h" +#include "ray/object_manager/plasma/allocator.h" #include "ray/object_manager/plasma/common.h" namespace plasma { @@ -41,7 +41,8 @@ namespace plasma { class PlasmaAllocator : public IAllocator { public: PlasmaAllocator(const std::string &plasma_directory, - const std::string &fallback_directory, bool hugepage_enabled, + const std::string &fallback_directory, + bool hugepage_enabled, int64_t footprint_limit); /// On linux, it allocates memory from a pre-mmapped file from /dev/shm. 
diff --git a/src/ray/object_manager/plasma/protocol.cc b/src/ray/object_manager/plasma/protocol.cc index a06eab3d0..181b263d4 100644 --- a/src/ray/object_manager/plasma/protocol.cc +++ b/src/ray/object_manager/plasma/protocol.cc @@ -58,7 +58,8 @@ inline T *MakeNonNull(T *maybe_null) { } flatbuffers::Offset>> -ToFlatbuffer(flatbuffers::FlatBufferBuilder *fbb, const ObjectID *object_ids, +ToFlatbuffer(flatbuffers::FlatBufferBuilder *fbb, + const ObjectID *object_ids, int64_t num_objects) { std::vector> results; for (int64_t i = 0; i < num_objects; i++) { @@ -84,7 +85,8 @@ flatbuffers::Offset> ToFlatbuffer( } Status PlasmaReceive(const std::shared_ptr &store_conn, - MessageType message_type, std::vector *buffer) { + MessageType message_type, + std::vector *buffer) { if (!store_conn) { return Status::IOError("Connection is closed."); } @@ -104,7 +106,8 @@ void ToVector(const Data &request, std::vector *out, const Getter &getter) { } template -void ConvertToVector(const FlatbufferVectorPointer fbvector, std::vector *out, +void ConvertToVector(const FlatbufferVectorPointer fbvector, + std::vector *out, const Converter &converter) { out->clear(); out->reserve(fbvector->size()); @@ -114,25 +117,29 @@ void ConvertToVector(const FlatbufferVectorPointer fbvector, std::vector *out } template -Status PlasmaSend(const std::shared_ptr &store_conn, MessageType message_type, - flatbuffers::FlatBufferBuilder *fbb, const Message &message) { +Status PlasmaSend(const std::shared_ptr &store_conn, + MessageType message_type, + flatbuffers::FlatBufferBuilder *fbb, + const Message &message) { if (!store_conn) { return Status::IOError("Connection is closed."); } fbb->Finish(message); - return store_conn->WriteMessage(static_cast(message_type), fbb->GetSize(), - fbb->GetBufferPointer()); + return store_conn->WriteMessage( + static_cast(message_type), fbb->GetSize(), fbb->GetBufferPointer()); } template -Status PlasmaSend(const std::shared_ptr &client, MessageType message_type, - flatbuffers::FlatBufferBuilder *fbb, const Message &message) { +Status PlasmaSend(const std::shared_ptr &client, + MessageType message_type, + flatbuffers::FlatBufferBuilder *fbb, + const Message &message) { if (!client) { return Status::IOError("Connection is closed."); } fbb->Finish(message); - return client->WriteMessage(static_cast(message_type), fbb->GetSize(), - fbb->GetBufferPointer()); + return client->WriteMessage( + static_cast(message_type), fbb->GetSize(), fbb->GetBufferPointer()); } Status PlasmaErrorStatus(fb::PlasmaError plasma_error) { @@ -180,29 +187,43 @@ Status ReadGetDebugStringReply(uint8_t *data, size_t size, std::string *debug_st // Create messages. 
Status SendCreateRetryRequest(const std::shared_ptr &store_conn, - ObjectID object_id, uint64_t request_id) { + ObjectID object_id, + uint64_t request_id) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaCreateRetryRequest( fbb, fbb.CreateString(object_id.Binary()), request_id); return PlasmaSend(store_conn, MessageType::PlasmaCreateRetryRequest, &fbb, message); } -Status SendCreateRequest(const std::shared_ptr &store_conn, ObjectID object_id, - const ray::rpc::Address &owner_address, int64_t data_size, - int64_t metadata_size, flatbuf::ObjectSource source, - int device_num, bool try_immediately) { +Status SendCreateRequest(const std::shared_ptr &store_conn, + ObjectID object_id, + const ray::rpc::Address &owner_address, + int64_t data_size, + int64_t metadata_size, + flatbuf::ObjectSource source, + int device_num, + bool try_immediately) { flatbuffers::FlatBufferBuilder fbb; - auto message = fb::CreatePlasmaCreateRequest( - fbb, fbb.CreateString(object_id.Binary()), - fbb.CreateString(owner_address.raylet_id()), - fbb.CreateString(owner_address.ip_address()), owner_address.port(), - fbb.CreateString(owner_address.worker_id()), data_size, metadata_size, source, - device_num, try_immediately); + auto message = + fb::CreatePlasmaCreateRequest(fbb, + fbb.CreateString(object_id.Binary()), + fbb.CreateString(owner_address.raylet_id()), + fbb.CreateString(owner_address.ip_address()), + owner_address.port(), + fbb.CreateString(owner_address.worker_id()), + data_size, + metadata_size, + source, + device_num, + try_immediately); return PlasmaSend(store_conn, MessageType::PlasmaCreateRequest, &fbb, message); } -void ReadCreateRequest(uint8_t *data, size_t size, ray::ObjectInfo *object_info, - flatbuf::ObjectSource *source, int *device_num) { +void ReadCreateRequest(uint8_t *data, + size_t size, + ray::ObjectInfo *object_info, + flatbuf::ObjectSource *source, + int *device_num) { RAY_DCHECK(data); auto message = flatbuffers::GetRoot(data); RAY_DCHECK(VerifyFlatbuffer(message, data, size)); @@ -219,7 +240,8 @@ void ReadCreateRequest(uint8_t *data, size_t size, ray::ObjectInfo *object_info, } Status SendUnfinishedCreateReply(const std::shared_ptr &client, - ObjectID object_id, uint64_t retry_with_request_id) { + ObjectID object_id, + uint64_t retry_with_request_id) { flatbuffers::FlatBufferBuilder fbb; auto object_string = fbb.CreateString(object_id.Binary()); fb::PlasmaCreateReplyBuilder crb(fbb); @@ -229,12 +251,18 @@ Status SendUnfinishedCreateReply(const std::shared_ptr &client, return PlasmaSend(client, MessageType::PlasmaCreateReply, &fbb, message); } -Status SendCreateReply(const std::shared_ptr &client, ObjectID object_id, - const PlasmaObject &object, PlasmaError error_code) { +Status SendCreateReply(const std::shared_ptr &client, + ObjectID object_id, + const PlasmaObject &object, + PlasmaError error_code) { flatbuffers::FlatBufferBuilder fbb; - PlasmaObjectSpec plasma_object( - FD2INT(object.store_fd.first), object.store_fd.second, object.data_offset, - object.data_size, object.metadata_offset, object.metadata_size, object.device_num); + PlasmaObjectSpec plasma_object(FD2INT(object.store_fd.first), + object.store_fd.second, + object.data_offset, + object.data_size, + object.metadata_offset, + object.metadata_size, + object.device_num); auto object_string = fbb.CreateString(object_id.Binary()); fb::PlasmaCreateReplyBuilder crb(fbb); crb.add_error(static_cast(error_code)); @@ -251,9 +279,13 @@ Status SendCreateReply(const std::shared_ptr &client, ObjectID object_id return 
PlasmaSend(client, MessageType::PlasmaCreateReply, &fbb, message); } -Status ReadCreateReply(uint8_t *data, size_t size, ObjectID *object_id, - uint64_t *retry_with_request_id, PlasmaObject *object, - MEMFD_TYPE *store_fd, int64_t *mmap_size) { +Status ReadCreateReply(uint8_t *data, + size_t size, + ObjectID *object_id, + uint64_t *retry_with_request_id, + PlasmaObject *object, + MEMFD_TYPE *store_fd, + int64_t *mmap_size) { RAY_DCHECK(data); auto message = flatbuffers::GetRoot(data); RAY_DCHECK(VerifyFlatbuffer(message, data, size)); @@ -324,7 +356,8 @@ Status ReadSealRequest(uint8_t *data, size_t size, ObjectID *object_id) { return Status::OK(); } -Status SendSealReply(const std::shared_ptr &client, ObjectID object_id, +Status SendSealReply(const std::shared_ptr &client, + ObjectID object_id, PlasmaError error) { flatbuffers::FlatBufferBuilder fbb; auto message = @@ -358,7 +391,8 @@ Status ReadReleaseRequest(uint8_t *data, size_t size, ObjectID *object_id) { return Status::OK(); } -Status SendReleaseReply(const std::shared_ptr &client, ObjectID object_id, +Status SendReleaseReply(const std::shared_ptr &client, + ObjectID object_id, PlasmaError error) { flatbuffers::FlatBufferBuilder fbb; auto message = @@ -380,7 +414,8 @@ Status SendDeleteRequest(const std::shared_ptr &store_conn, const std::vector &object_ids) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaDeleteRequest( - fbb, static_cast(object_ids.size()), + fbb, + static_cast(object_ids.size()), ToFlatbuffer(&fbb, &object_ids[0], object_ids.size())); return PlasmaSend(store_conn, MessageType::PlasmaDeleteRequest, &fbb, message); } @@ -404,14 +439,17 @@ Status SendDeleteReply(const std::shared_ptr &client, RAY_DCHECK(object_ids.size() == errors.size()); flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaDeleteReply( - fbb, static_cast(object_ids.size()), + fbb, + static_cast(object_ids.size()), ToFlatbuffer(&fbb, &object_ids[0], object_ids.size()), fbb.CreateVector(MakeNonNull(reinterpret_cast(errors.data())), object_ids.size())); return PlasmaSend(client, MessageType::PlasmaDeleteReply, &fbb, message); } -Status ReadDeleteReply(uint8_t *data, size_t size, std::vector *object_ids, +Status ReadDeleteReply(uint8_t *data, + size_t size, + std::vector *object_ids, std::vector *errors) { using fb::PlasmaDeleteReply; @@ -447,15 +485,18 @@ Status ReadContainsRequest(uint8_t *data, size_t size, ObjectID *object_id) { return Status::OK(); } -Status SendContainsReply(const std::shared_ptr &client, ObjectID object_id, +Status SendContainsReply(const std::shared_ptr &client, + ObjectID object_id, bool has_object) { flatbuffers::FlatBufferBuilder fbb; - auto message = fb::CreatePlasmaContainsReply(fbb, fbb.CreateString(object_id.Binary()), - has_object); + auto message = fb::CreatePlasmaContainsReply( + fbb, fbb.CreateString(object_id.Binary()), has_object); return PlasmaSend(client, MessageType::PlasmaContainsReply, &fbb, message); } -Status ReadContainsReply(uint8_t *data, size_t size, ObjectID *object_id, +Status ReadContainsReply(uint8_t *data, + size_t size, + ObjectID *object_id, bool *has_object) { RAY_DCHECK(data); auto message = flatbuffers::GetRoot(data); @@ -522,7 +563,9 @@ Status ReadEvictReply(uint8_t *data, size_t size, int64_t &num_bytes) { // Get messages. 
Status SendGetRequest(const std::shared_ptr &store_conn, - const ObjectID *object_ids, int64_t num_objects, int64_t timeout_ms, + const ObjectID *object_ids, + int64_t num_objects, + int64_t timeout_ms, bool is_from_worker) { flatbuffers::FlatBufferBuilder fbb; auto message = fb::CreatePlasmaGetRequest( @@ -530,8 +573,11 @@ Status SendGetRequest(const std::shared_ptr &store_conn, return PlasmaSend(store_conn, MessageType::PlasmaGetRequest, &fbb, message); } -Status ReadGetRequest(uint8_t *data, size_t size, std::vector &object_ids, - int64_t *timeout_ms, bool *is_from_worker) { +Status ReadGetRequest(uint8_t *data, + size_t size, + std::vector &object_ids, + int64_t *timeout_ms, + bool *is_from_worker) { RAY_DCHECK(data); auto message = flatbuffers::GetRoot(data); RAY_DCHECK(VerifyFlatbuffer(message, data, size)); @@ -544,9 +590,11 @@ Status ReadGetRequest(uint8_t *data, size_t size, std::vector &object_ return Status::OK(); } -Status SendGetReply(const std::shared_ptr &client, ObjectID object_ids[], +Status SendGetReply(const std::shared_ptr &client, + ObjectID object_ids[], absl::flat_hash_map &plasma_objects, - int64_t num_objects, const std::vector &store_fds, + int64_t num_objects, + const std::vector &store_fds, const std::vector &mmap_sizes) { flatbuffers::FlatBufferBuilder fbb; std::vector objects; @@ -558,9 +606,12 @@ Status SendGetReply(const std::shared_ptr &client, ObjectID object_ids[] << " data_size: " << object.data_size << " metadata_size: " << object.metadata_size; objects.push_back(PlasmaObjectSpec(FD2INT(object.store_fd.first), - object.store_fd.second, object.data_offset, - object.data_size, object.metadata_offset, - object.metadata_size, object.device_num)); + object.store_fd.second, + object.data_offset, + object.data_size, + object.metadata_offset, + object.metadata_size, + object.device_num)); } std::vector store_fds_as_int; std::vector unique_fd_ids; @@ -569,7 +620,8 @@ Status SendGetReply(const std::shared_ptr &client, ObjectID object_ids[] unique_fd_ids.push_back(store_fd.second); } auto message = fb::CreatePlasmaGetReply( - fbb, ToFlatbuffer(&fbb, object_ids, num_objects), + fbb, + ToFlatbuffer(&fbb, object_ids, num_objects), fbb.CreateVectorOfStructs(MakeNonNull(objects.data()), num_objects), fbb.CreateVector(MakeNonNull(store_fds_as_int.data()), store_fds_as_int.size()), fbb.CreateVector(MakeNonNull(unique_fd_ids.data()), unique_fd_ids.size()), @@ -578,8 +630,11 @@ Status SendGetReply(const std::shared_ptr &client, ObjectID object_ids[] return PlasmaSend(client, MessageType::PlasmaGetReply, &fbb, message); } -Status ReadGetReply(uint8_t *data, size_t size, ObjectID object_ids[], - PlasmaObject plasma_objects[], int64_t num_objects, +Status ReadGetReply(uint8_t *data, + size_t size, + ObjectID object_ids[], + PlasmaObject plasma_objects[], + int64_t num_objects, std::vector &store_fds, std::vector &mmap_sizes) { RAY_DCHECK(data); diff --git a/src/ray/object_manager/plasma/protocol.h b/src/ray/object_manager/plasma/protocol.h index 2364d9805..ce965103c 100644 --- a/src/ray/object_manager/plasma/protocol.h +++ b/src/ray/object_manager/plasma/protocol.h @@ -50,7 +50,8 @@ bool VerifyFlatbuffer(T *object, uint8_t *data, size_t size) { } flatbuffers::Offset>> -ToFlatbuffer(flatbuffers::FlatBufferBuilder *fbb, const ObjectID *object_ids, +ToFlatbuffer(flatbuffers::FlatBufferBuilder *fbb, + const ObjectID *object_ids, int64_t num_objects); flatbuffers::Offset>> @@ -63,7 +64,8 @@ flatbuffers::Offset> ToFlatbuffer( /* Plasma receive message. 
*/ Status PlasmaReceive(const std::shared_ptr &store_conn, - MessageType message_type, std::vector *buffer); + MessageType message_type, + std::vector *buffer); /* Debug string messages. */ @@ -77,25 +79,40 @@ Status ReadGetDebugStringReply(uint8_t *data, size_t size, std::string *debug_st /* Plasma Create message functions. */ Status SendCreateRetryRequest(const std::shared_ptr &store_conn, - ObjectID object_id, uint64_t request_id); + ObjectID object_id, + uint64_t request_id); -Status SendCreateRequest(const std::shared_ptr &store_conn, ObjectID object_id, - const ray::rpc::Address &owner_address, int64_t data_size, - int64_t metadata_size, flatbuf::ObjectSource source, - int device_num, bool try_immediately); +Status SendCreateRequest(const std::shared_ptr &store_conn, + ObjectID object_id, + const ray::rpc::Address &owner_address, + int64_t data_size, + int64_t metadata_size, + flatbuf::ObjectSource source, + int device_num, + bool try_immediately); -void ReadCreateRequest(uint8_t *data, size_t size, ray::ObjectInfo *object_info, - flatbuf::ObjectSource *source, int *device_num); +void ReadCreateRequest(uint8_t *data, + size_t size, + ray::ObjectInfo *object_info, + flatbuf::ObjectSource *source, + int *device_num); Status SendUnfinishedCreateReply(const std::shared_ptr &client, - ObjectID object_id, uint64_t retry_with_request_id); + ObjectID object_id, + uint64_t retry_with_request_id); -Status SendCreateReply(const std::shared_ptr &client, ObjectID object_id, - const PlasmaObject &object, PlasmaError error); +Status SendCreateReply(const std::shared_ptr &client, + ObjectID object_id, + const PlasmaObject &object, + PlasmaError error); -Status ReadCreateReply(uint8_t *data, size_t size, ObjectID *object_id, - uint64_t *retry_with_request_id, PlasmaObject *object, - MEMFD_TYPE *store_fd, int64_t *mmap_size); +Status ReadCreateReply(uint8_t *data, + size_t size, + ObjectID *object_id, + uint64_t *retry_with_request_id, + PlasmaObject *object, + MEMFD_TYPE *store_fd, + int64_t *mmap_size); Status SendAbortRequest(const std::shared_ptr &store_conn, ObjectID object_id); @@ -111,7 +128,8 @@ Status SendSealRequest(const std::shared_ptr &store_conn, ObjectID ob Status ReadSealRequest(uint8_t *data, size_t size, ObjectID *object_id); -Status SendSealReply(const std::shared_ptr &client, ObjectID object_id, +Status SendSealReply(const std::shared_ptr &client, + ObjectID object_id, PlasmaError error); Status ReadSealReply(uint8_t *data, size_t size, ObjectID *object_id); @@ -119,20 +137,31 @@ Status ReadSealReply(uint8_t *data, size_t size, ObjectID *object_id); /* Plasma Get message functions. 
*/ Status SendGetRequest(const std::shared_ptr &store_conn, - const ObjectID *object_ids, int64_t num_objects, int64_t timeout_ms, + const ObjectID *object_ids, + int64_t num_objects, + int64_t timeout_ms, bool is_from_worker); -Status ReadGetRequest(uint8_t *data, size_t size, std::vector &object_ids, - int64_t *timeout_ms, bool *is_from_worker); +Status ReadGetRequest(uint8_t *data, + size_t size, + std::vector &object_ids, + int64_t *timeout_ms, + bool *is_from_worker); -Status SendGetReply(const std::shared_ptr &client, ObjectID object_ids[], +Status SendGetReply(const std::shared_ptr &client, + ObjectID object_ids[], absl::flat_hash_map &plasma_objects, - int64_t num_objects, const std::vector &store_fds, + int64_t num_objects, + const std::vector &store_fds, const std::vector &mmap_sizes); -Status ReadGetReply(uint8_t *data, size_t size, ObjectID object_ids[], - PlasmaObject plasma_objects[], int64_t num_objects, - std::vector &store_fds, std::vector &mmap_sizes); +Status ReadGetReply(uint8_t *data, + size_t size, + ObjectID object_ids[], + PlasmaObject plasma_objects[], + int64_t num_objects, + std::vector &store_fds, + std::vector &mmap_sizes); /* Plasma Release message functions. */ @@ -141,7 +170,8 @@ Status SendReleaseRequest(const std::shared_ptr &store_conn, Status ReadReleaseRequest(uint8_t *data, size_t size, ObjectID *object_id); -Status SendReleaseReply(const std::shared_ptr &client, ObjectID object_id, +Status SendReleaseReply(const std::shared_ptr &client, + ObjectID object_id, PlasmaError error); Status ReadReleaseReply(uint8_t *data, size_t size, ObjectID *object_id); @@ -157,7 +187,9 @@ Status SendDeleteReply(const std::shared_ptr &client, const std::vector &object_ids, const std::vector &errors); -Status ReadDeleteReply(uint8_t *data, size_t size, std::vector *object_ids, +Status ReadDeleteReply(uint8_t *data, + size_t size, + std::vector *object_ids, std::vector *errors); /* Plasma Contains message functions. */ @@ -167,10 +199,13 @@ Status SendContainsRequest(const std::shared_ptr &store_conn, Status ReadContainsRequest(uint8_t *data, size_t size, ObjectID *object_id); -Status SendContainsReply(const std::shared_ptr &client, ObjectID object_id, +Status SendContainsReply(const std::shared_ptr &client, + ObjectID object_id, bool has_object); -Status ReadContainsReply(uint8_t *data, size_t size, ObjectID *object_id, +Status ReadContainsReply(uint8_t *data, + size_t size, + ObjectID *object_id, bool *has_object); /* Plasma Connect message functions. 
*/ diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc index 6430cf88b..5c42c816e 100644 --- a/src/ray/object_manager/plasma/store.cc +++ b/src/ray/object_manager/plasma/store.cc @@ -69,8 +69,10 @@ ray::ObjectID GetCreateRequestObjectId(const std::vector &message) { } } // namespace -PlasmaStore::PlasmaStore(instrumented_io_context &main_service, IAllocator &allocator, - const std::string &socket_name, uint32_t delay_on_oom_ms, +PlasmaStore::PlasmaStore(instrumented_io_context &main_service, + IAllocator &allocator, + const std::string &socket_name, + uint32_t delay_on_oom_ms, float object_spilling_threshold, ray::SpillObjectsCallback spill_objects_callback, std::function object_store_full_callback, @@ -88,7 +90,8 @@ PlasmaStore::PlasmaStore(instrumented_io_context &main_service, IAllocator &allo object_spilling_threshold_(object_spilling_threshold), create_request_queue_( /*oom_grace_period_s=*/RayConfig::instance().oom_grace_period_s(), - spill_objects_callback, object_store_full_callback, + spill_objects_callback, + object_store_full_callback, /*get_time=*/ []() { return absl::GetCurrentTimeNanos(); }, // absl can't check thread safety for lambda @@ -98,7 +101,8 @@ PlasmaStore::PlasmaStore(instrumented_io_context &main_service, IAllocator &allo }), total_consumed_bytes_(0), get_request_queue_( - io_context_, object_lifecycle_mgr_, + io_context_, + object_lifecycle_mgr_, // absl failed to check thread safety for lambda [this](const ObjectID &object_id, const auto &request) ABSL_NO_THREAD_SAFETY_ANALYSIS { @@ -181,7 +185,8 @@ PlasmaError PlasmaStore::HandleCreateObjectRequest(const std::shared_ptr PlasmaError PlasmaStore::CreateObject(const ray::ObjectInfo &object_info, fb::ObjectSource source, const std::shared_ptr &client, - bool fallback_allocator, PlasmaObject *result) { + bool fallback_allocator, + PlasmaObject *result) { auto pair = object_lifecycle_mgr_.CreateObject(object_info, source, fallback_allocator); auto entry = pair.first; auto error = pair.second; @@ -219,8 +224,11 @@ void PlasmaStore::ReturnFromGet(const std::shared_ptr &get_request) } // Send the get reply to the client. Status s = SendGetReply(std::dynamic_pointer_cast(get_request->client), - &get_request->object_ids[0], get_request->objects, - get_request->object_ids.size(), store_fds, mmap_sizes); + &get_request->object_ids[0], + get_request->objects, + get_request->object_ids.size(), + store_fds, + mmap_sizes); // If we successfully sent the get reply message to the client, then also send // the file descriptors. 
if (s.ok()) { @@ -239,7 +247,8 @@ void PlasmaStore::ReturnFromGet(const std::shared_ptr &get_request) void PlasmaStore::ProcessGetRequest(const std::shared_ptr &client, const std::vector &object_ids, - int64_t timeout_ms, bool is_from_worker) { + int64_t timeout_ms, + bool is_from_worker) { get_request_queue_.AddRequest(client, object_ids, timeout_ms, is_from_worker); } @@ -357,13 +366,14 @@ Status PlasmaStore::ProcessMessage(const std::shared_ptr &client, const size_t object_size = request->data_size() + request->metadata_size(); // absl failed analyze mutex safety for lambda - auto handle_create = [this, client, message]( - bool fallback_allocator, PlasmaObject *result, - bool *spilling_required) ABSL_NO_THREAD_SAFETY_ANALYSIS { - mutex_.AssertHeld(); - return HandleCreateObjectRequest(client, message, fallback_allocator, result, - spilling_required); - }; + auto handle_create = + [this, client, message]( + bool fallback_allocator, PlasmaObject *result, bool *spilling_required) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + mutex_.AssertHeld(); + return HandleCreateObjectRequest( + client, message, fallback_allocator, result, spilling_required); + }; if (request->try_immediately()) { RAY_LOG(DEBUG) << "Received request to create object " << object_id @@ -403,8 +413,8 @@ Status PlasmaStore::ProcessMessage(const std::shared_ptr &client, std::vector object_ids_to_get; int64_t timeout_ms; bool is_from_worker; - RAY_RETURN_NOT_OK(ReadGetRequest(input, input_size, object_ids_to_get, &timeout_ms, - &is_from_worker)); + RAY_RETURN_NOT_OK(ReadGetRequest( + input, input_size, object_ids_to_get, &timeout_ms, &is_from_worker)); ProcessGetRequest(client, object_ids_to_get, timeout_ms, is_from_worker); } break; case fb::MessageType::PlasmaReleaseRequest: { @@ -461,8 +471,9 @@ Status PlasmaStore::ProcessMessage(const std::shared_ptr &client, } void PlasmaStore::DoAccept() { - acceptor_.async_accept(socket_, boost::bind(&PlasmaStore::ConnectClient, this, - boost::asio::placeholders::error)); + acceptor_.async_accept( + socket_, + boost::bind(&PlasmaStore::ConnectClient, this, boost::asio::placeholders::error)); } void PlasmaStore::ProcessCreateRequests() { @@ -501,7 +512,8 @@ void PlasmaStore::ProcessCreateRequests() { } void PlasmaStore::ReplyToCreateClient(const std::shared_ptr &client, - const ObjectID &object_id, uint64_t req_id) { + const ObjectID &object_id, + uint64_t req_id) { PlasmaObject result = {}; PlasmaError error; bool finished = create_request_queue_.GetRequestResult(req_id, &result, &error); @@ -533,7 +545,8 @@ void PlasmaStore::PrintAndRecordDebugDump() const { RecordMetrics(); RAY_LOG(INFO) << GetDebugDump(); stats_timer_ = execute_after( - io_context_, [this]() { PrintAndRecordDebugDump(); }, + io_context_, + [this]() { PrintAndRecordDebugDump(); }, RayConfig::instance().event_stats_print_interval_ms()); } diff --git a/src/ray/object_manager/plasma/store.h b/src/ray/object_manager/plasma/store.h index 0a043a020..758900305 100644 --- a/src/ray/object_manager/plasma/store.h +++ b/src/ray/object_manager/plasma/store.h @@ -53,8 +53,10 @@ using flatbuf::PlasmaError; class PlasmaStore { public: - PlasmaStore(instrumented_io_context &main_service, IAllocator &allocator, - const std::string &socket_name, uint32_t delay_on_oom_ms, + PlasmaStore(instrumented_io_context &main_service, + IAllocator &allocator, + const std::string &socket_name, + uint32_t delay_on_oom_ms, float object_spilling_threshold, ray::SpillObjectsCallback spill_objects_callback, std::function object_store_full_callback, @@ -122,7 
+124,8 @@ class PlasmaStore { /// plasma_release. PlasmaError CreateObject(const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source, - const std::shared_ptr &client, bool fallback_allocator, + const std::shared_ptr &client, + bool fallback_allocator, PlasmaObject *result) EXCLUSIVE_LOCKS_REQUIRED(mutex_); /// Abort a created but unsealed object. If the client is not the @@ -156,7 +159,8 @@ class PlasmaStore { /// \param object_ids Object IDs of the objects to be gotten. /// \param timeout_ms The timeout for the get request in milliseconds. void ProcessGetRequest(const std::shared_ptr &client, - const std::vector &object_ids, int64_t timeout_ms, + const std::vector &object_ids, + int64_t timeout_ms, bool is_from_worker) EXCLUSIVE_LOCKS_REQUIRED(mutex_); /// Process queued requests to create an object. @@ -194,13 +198,14 @@ class PlasmaStore { PlasmaError HandleCreateObjectRequest(const std::shared_ptr &client, const std::vector &message, - bool fallback_allocator, PlasmaObject *object, + bool fallback_allocator, + PlasmaObject *object, bool *spilling_required) EXCLUSIVE_LOCKS_REQUIRED(mutex_); void ReplyToCreateClient(const std::shared_ptr &client, - const ObjectID &object_id, uint64_t req_id) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + const ObjectID &object_id, + uint64_t req_id) EXCLUSIVE_LOCKS_REQUIRED(mutex_); void AddToClientObjectIds(const ObjectID &object_id, const std::shared_ptr &client) diff --git a/src/ray/object_manager/plasma/store_runner.cc b/src/ray/object_manager/plasma/store_runner.cc index 97a1db3e3..b5727d0d8 100644 --- a/src/ray/object_manager/plasma/store_runner.cc +++ b/src/ray/object_manager/plasma/store_runner.cc @@ -15,8 +15,10 @@ namespace internal { void SetMallocGranularity(int value); } -PlasmaStoreRunner::PlasmaStoreRunner(std::string socket_name, int64_t system_memory, - bool hugepages_enabled, std::string plasma_directory, +PlasmaStoreRunner::PlasmaStoreRunner(std::string socket_name, + int64_t system_memory, + bool hugepages_enabled, + std::string plasma_directory, std::string fallback_directory) : hugepages_enabled_(hugepages_enabled) { // Sanity check. 
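One subtlety visible in the store.h hunk above: when a declaration carries a trailing thread-safety annotation, the reflowed form keeps the annotation on the same line as the closing parameter rather than wrapping it separately. A minimal sketch (the Client template argument is restored here by assumption, since the extraction dropped angle brackets):

    // Declaration with a trailing lock annotation after reflowing.
    void ReplyToCreateClient(const std::shared_ptr<Client> &client,
                             const ObjectID &object_id,
                             uint64_t req_id) EXCLUSIVE_LOCKS_REQUIRED(mutex_);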
@@ -86,13 +88,17 @@ void PlasmaStoreRunner::Start(ray::SpillObjectsCallback spill_objects_callback, RAY_LOG(DEBUG) << "starting server listening on " << socket_name_; { absl::MutexLock lock(&store_runner_mutex_); - allocator_ = std::make_unique(plasma_directory_, fallback_directory_, - hugepages_enabled_, system_memory_); - store_.reset(new PlasmaStore(main_service_, *allocator_, socket_name_, + allocator_ = std::make_unique( + plasma_directory_, fallback_directory_, hugepages_enabled_, system_memory_); + store_.reset(new PlasmaStore(main_service_, + *allocator_, + socket_name_, RayConfig::instance().object_store_full_delay_ms(), RayConfig::instance().object_spilling_threshold(), - spill_objects_callback, object_store_full_callback, - add_object_callback, delete_object_callback)); + spill_objects_callback, + object_store_full_callback, + add_object_callback, + delete_object_callback)); store_->Start(); } main_service_.run(); diff --git a/src/ray/object_manager/plasma/store_runner.h b/src/ray/object_manager/plasma/store_runner.h index 955b66cba..e2962be4d 100644 --- a/src/ray/object_manager/plasma/store_runner.h +++ b/src/ray/object_manager/plasma/store_runner.h @@ -12,8 +12,10 @@ namespace plasma { class PlasmaStoreRunner { public: - PlasmaStoreRunner(std::string socket_name, int64_t system_memory, - bool hugepages_enabled, std::string plasma_directory, + PlasmaStoreRunner(std::string socket_name, + int64_t system_memory, + bool hugepages_enabled, + std::string plasma_directory, std::string fallback_directory); void Start(ray::SpillObjectsCallback spill_objects_callback, std::function object_store_full_callback, diff --git a/src/ray/object_manager/plasma/test/eviction_policy_test.cc b/src/ray/object_manager/plasma/test/eviction_policy_test.cc index ebd2d13d2..abe8071f0 100644 --- a/src/ray/object_manager/plasma/test/eviction_policy_test.cc +++ b/src/ray/object_manager/plasma/test/eviction_policy_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "ray/object_manager/plasma/eviction_policy.h" + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "ray/object_manager/plasma/object_store.h" @@ -87,8 +88,10 @@ class MockAllocator : public IAllocator { class MockObjectStore : public IObjectStore { public: - MOCK_METHOD3(CreateObject, const LocalObject *(const ray::ObjectInfo &, - plasma::flatbuf::ObjectSource, bool)); + MOCK_METHOD3(CreateObject, + const LocalObject *(const ray::ObjectInfo &, + plasma::flatbuf::ObjectSource, + bool)); MOCK_CONST_METHOD1(GetObject, const LocalObject *(const ObjectID &)); MOCK_METHOD1(SealObject, const LocalObject *(const ObjectID &)); MOCK_METHOD1(DeleteObject, bool(const ObjectID &)); diff --git a/src/ray/object_manager/plasma/test/fallback_allocator_test.cc b/src/ray/object_manager/plasma/test/fallback_allocator_test.cc index e9de42bc9..d76a2c106 100644 --- a/src/ray/object_manager/plasma/test/fallback_allocator_test.cc +++ b/src/ray/object_manager/plasma/test/fallback_allocator_test.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include + #include "gtest/gtest.h" #include "ray/object_manager/plasma/plasma_allocator.h" @@ -33,8 +34,10 @@ TEST(FallbackPlasmaAllocatorTest, FallbackPassThroughTest) { auto fallback_directory = CreateTestDir(); int64_t kLimit = 256 * sizeof(size_t) + 2 * kMB; int64_t object_size = 900 * 1024; - PlasmaAllocator allocator(plasma_directory, fallback_directory, - /* hugepage_enabled */ false, kLimit); + PlasmaAllocator allocator(plasma_directory, + fallback_directory, + /* hugepage_enabled */ false, + kLimit); EXPECT_EQ(kLimit, allocator.GetFootprintLimit()); diff --git a/src/ray/object_manager/plasma/test/object_lifecycle_manager_test.cc b/src/ray/object_manager/plasma/test/object_lifecycle_manager_test.cc index 9705fa0cd..039b1b8a8 100644 --- a/src/ray/object_manager/plasma/test/object_lifecycle_manager_test.cc +++ b/src/ray/object_manager/plasma/test/object_lifecycle_manager_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "ray/object_manager/plasma/object_lifecycle_manager.h" + #include #include "absl/random/random.h" @@ -19,8 +21,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "ray/object_manager/plasma/object_lifecycle_manager.h" - using namespace ray; using namespace testing; @@ -39,8 +39,10 @@ class MockEvictionPolicy : public IEvictionPolicy { class MockObjectStore : public IObjectStore { public: - MOCK_METHOD3(CreateObject, const LocalObject *(const ray::ObjectInfo &, - plasma::flatbuf::ObjectSource, bool)); + MOCK_METHOD3(CreateObject, + const LocalObject *(const ray::ObjectInfo &, + plasma::flatbuf::ObjectSource, + bool)); MOCK_CONST_METHOD1(GetObject, const LocalObject *(const ObjectID &)); MOCK_METHOD1(SealObject, const LocalObject *(const ObjectID &)); MOCK_METHOD1(DeleteObject, bool(const ObjectID &)); @@ -54,9 +56,10 @@ struct ObjectLifecycleManagerTest : public Test { auto object_store = std::make_unique(); eviction_policy_ = eviction_policy.get(); object_store_ = object_store.get(); - manager_ = std::make_unique( - ObjectLifecycleManager(std::move(object_store), std::move(eviction_policy), - [this](auto &id) { notify_deleted_ids_.push_back(id); })); + manager_ = std::make_unique(ObjectLifecycleManager( + std::move(object_store), std::move(eviction_policy), [this](auto &id) { + notify_deleted_ids_.push_back(id); + })); sealed_object_.state = ObjectState::PLASMA_SEALED; not_sealed_object_.state = ObjectState::PLASMA_CREATED; one_ref_object_.state = ObjectState::PLASMA_SEALED; diff --git a/src/ray/object_manager/plasma/test/object_store_test.cc b/src/ray/object_manager/plasma/test/object_store_test.cc index 6b6dc55a8..8520d4893 100644 --- a/src/ray/object_manager/plasma/test/object_store_test.cc +++ b/src/ray/object_manager/plasma/test/object_store_test.cc @@ -13,7 +13,9 @@ // limitations under the License. 
#include "ray/object_manager/plasma/object_store.h" + #include + #include "absl/random/random.h" #include "absl/strings/str_format.h" #include "gmock/gmock.h" @@ -38,9 +40,14 @@ Allocation CreateAllocation(Allocation alloc, int64_t size) { } const std::string Serialize(const Allocation &allocation) { - return absl::StrFormat("%p/%d/%d/%d/%d/%d/%d", allocation.address, allocation.size, - allocation.fd.first, allocation.fd.second, allocation.offset, - allocation.device_num, allocation.mmap_size); + return absl::StrFormat("%p/%d/%d/%d/%d/%d/%d", + allocation.address, + allocation.size, + allocation.fd.first, + allocation.fd.second, + allocation.offset, + allocation.device_num, + allocation.mmap_size); } ObjectInfo CreateObjectInfo(ObjectID object_id, int64_t object_size) { diff --git a/src/ray/object_manager/plasma/test/stats_collector_test.cc b/src/ray/object_manager/plasma/test/stats_collector_test.cc index b8570f0c8..0f885b7fa 100644 --- a/src/ray/object_manager/plasma/test/stats_collector_test.cc +++ b/src/ray/object_manager/plasma/test/stats_collector_test.cc @@ -15,7 +15,6 @@ #include #include "absl/random/random.h" - #include "ray/object_manager/plasma/object_lifecycle_manager.h" using namespace ray; @@ -174,9 +173,10 @@ struct ObjectStatsCollectorTest : public Test { }; TEST_F(ObjectStatsCollectorTest, CreateAndAbort) { - std::vector sources = { - ObjectSource::CreatedByWorker, ObjectSource::RestoredFromStorage, - ObjectSource::ReceivedFromRemoteRaylet, ObjectSource::ErrorStoredByRaylet}; + std::vector sources = {ObjectSource::CreatedByWorker, + ObjectSource::RestoredFromStorage, + ObjectSource::ReceivedFromRemoteRaylet, + ObjectSource::ErrorStoredByRaylet}; for (auto source : sources) { int64_t size = Random(100); @@ -193,9 +193,10 @@ TEST_F(ObjectStatsCollectorTest, CreateAndAbort) { } TEST_F(ObjectStatsCollectorTest, CreateAndDelete) { - std::vector sources = { - ObjectSource::CreatedByWorker, ObjectSource::RestoredFromStorage, - ObjectSource::ReceivedFromRemoteRaylet, ObjectSource::ErrorStoredByRaylet}; + std::vector sources = {ObjectSource::CreatedByWorker, + ObjectSource::RestoredFromStorage, + ObjectSource::ReceivedFromRemoteRaylet, + ObjectSource::ErrorStoredByRaylet}; for (auto source : sources) { int64_t size = Random(100); @@ -219,9 +220,10 @@ TEST_F(ObjectStatsCollectorTest, CreateAndDelete) { } TEST_F(ObjectStatsCollectorTest, Eviction) { - std::vector sources = { - ObjectSource::CreatedByWorker, ObjectSource::RestoredFromStorage, - ObjectSource::ReceivedFromRemoteRaylet, ObjectSource::ErrorStoredByRaylet}; + std::vector sources = {ObjectSource::CreatedByWorker, + ObjectSource::RestoredFromStorage, + ObjectSource::ReceivedFromRemoteRaylet, + ObjectSource::ErrorStoredByRaylet}; int64_t size = 100; for (auto source : sources) { diff --git a/src/ray/object_manager/pull_manager.cc b/src/ray/object_manager/pull_manager.cc index 8f45403c0..a66465ebb 100644 --- a/src/ray/object_manager/pull_manager.cc +++ b/src/ray/object_manager/pull_manager.cc @@ -20,12 +20,14 @@ namespace ray { PullManager::PullManager( - NodeID &self_node_id, const std::function object_is_local, + NodeID &self_node_id, + const std::function object_is_local, const std::function send_pull_request, const std::function cancel_pull_request, const std::function fail_pull_request, const RestoreSpilledObjectCallback restore_spilled_object, - const std::function get_time_seconds, int pull_timeout_ms, + const std::function get_time_seconds, + int pull_timeout_ms, int64_t num_bytes_available, std::function(const 
ObjectID &)> pin_object, std::function get_locally_spilled_object_url) @@ -186,7 +188,8 @@ bool PullManager::ActivateNextPullBundleRequest(const Queue &bundles, } void PullManager::DeactivatePullBundleRequest( - const Queue &bundles, const Queue::iterator &request_it, + const Queue &bundles, + const Queue::iterator &request_it, uint64_t *highest_req_id_being_pulled, std::unordered_set *objects_to_cancel) { for (const auto &ref : request_it->second.objects) { @@ -223,8 +226,12 @@ void PullManager::DeactivatePullBundleRequest( } void PullManager::DeactivateUntilMarginAvailable( - const std::string &debug_name, Queue &bundles, int retain_min, int64_t quota_margin, - uint64_t *highest_id_for_bundle, std::unordered_set *object_ids_to_cancel) { + const std::string &debug_name, + Queue &bundles, + int retain_min, + int64_t quota_margin, + uint64_t *highest_id_for_bundle, + std::unordered_set *object_ids_to_cancel) { while (RemainingQuota() < quota_margin && *highest_id_for_bundle != 0) { if (num_active_bundles_ <= retain_min) { return; @@ -234,8 +241,8 @@ void PullManager::DeactivateUntilMarginAvailable( << " num bytes available: " << num_bytes_available_; const auto last_request_it = bundles.find(*highest_id_for_bundle); RAY_CHECK(last_request_it != bundles.end()); - DeactivatePullBundleRequest(bundles, last_request_it, highest_id_for_bundle, - object_ids_to_cancel); + DeactivatePullBundleRequest( + bundles, last_request_it, highest_id_for_bundle, object_ids_to_cancel); } } @@ -268,19 +275,25 @@ void PullManager::UpdatePullsBasedOnAvailableMemory(int64_t num_bytes_available) while (get_requests_remaining) { int64_t margin_required = NextRequestBundleSize(get_request_bundles_, highest_get_req_id_being_pulled_); - DeactivateUntilMarginAvailable("task args request", task_argument_bundles_, - /*retain_min=*/0, /*quota_margin=*/margin_required, + DeactivateUntilMarginAvailable("task args request", + task_argument_bundles_, + /*retain_min=*/0, + /*quota_margin=*/margin_required, &highest_task_req_id_being_pulled_, &object_ids_to_cancel); - DeactivateUntilMarginAvailable("wait request", wait_request_bundles_, - /*retain_min=*/0, /*quota_margin=*/margin_required, + DeactivateUntilMarginAvailable("wait request", + wait_request_bundles_, + /*retain_min=*/0, + /*quota_margin=*/margin_required, &highest_wait_req_id_being_pulled_, &object_ids_to_cancel); // Activate the next get request unconditionally. - get_requests_remaining = ActivateNextPullBundleRequest( - get_request_bundles_, &highest_get_req_id_being_pulled_, - /*respect_quota=*/false, &objects_to_pull); + get_requests_remaining = + ActivateNextPullBundleRequest(get_request_bundles_, + &highest_get_req_id_being_pulled_, + /*respect_quota=*/false, + &objects_to_pull); } // Do the same but for wait requests (medium priority). @@ -288,30 +301,41 @@ void PullManager::UpdatePullsBasedOnAvailableMemory(int64_t num_bytes_available) while (wait_requests_remaining) { int64_t margin_required = NextRequestBundleSize(wait_request_bundles_, highest_wait_req_id_being_pulled_); - DeactivateUntilMarginAvailable("task args request", task_argument_bundles_, - /*retain_min=*/0, /*quota_margin=*/margin_required, + DeactivateUntilMarginAvailable("task args request", + task_argument_bundles_, + /*retain_min=*/0, + /*quota_margin=*/margin_required, &highest_task_req_id_being_pulled_, &object_ids_to_cancel); // Activate the next wait request if we have space. 
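// A condensed, self-contained sketch of the admission-control loop being
// reformatted in this hunk; the struct and names are illustrative stand-ins,
// not PullManager's real types. Before a higher-priority bundle is activated,
// lower-priority queues are deactivated from the back until the required quota
// margin is free, but never below `retain_min` active bundles.
#include <cstdint>

struct BundleQueueSketch {
  int num_active = 0;
  // Deactivate the most recently activated bundle and return the bytes freed.
  int64_t DeactivateLastActive() { --num_active; return 1 /* bytes freed */; }
};

inline void DeactivateUntilMarginAvailableSketch(BundleQueueSketch &queue,
                                                 int retain_min,
                                                 int64_t quota_margin,
                                                 int64_t &remaining_quota) {
  // Mirrors the loop shape above: stop once enough quota is free or only
  // `retain_min` bundles remain active.
  while (remaining_quota < quota_margin && queue.num_active > retain_min) {
    remaining_quota += queue.DeactivateLastActive();
  }
}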
- wait_requests_remaining = ActivateNextPullBundleRequest( - wait_request_bundles_, &highest_wait_req_id_being_pulled_, - /*respect_quota=*/true, &objects_to_pull); + wait_requests_remaining = + ActivateNextPullBundleRequest(wait_request_bundles_, + &highest_wait_req_id_being_pulled_, + /*respect_quota=*/true, + &objects_to_pull); } // Do the same but for task arg requests (lowest priority). // allowed for task arg requests. while (ActivateNextPullBundleRequest(task_argument_bundles_, &highest_task_req_id_being_pulled_, - /*respect_quota=*/true, &objects_to_pull)) { + /*respect_quota=*/true, + &objects_to_pull)) { } // While we are over capacity, deactivate requests starting from the back of the queues. - DeactivateUntilMarginAvailable( - "task args request", task_argument_bundles_, /*retain_min=*/1, /*quota_margin=*/0L, - &highest_task_req_id_being_pulled_, &object_ids_to_cancel); - DeactivateUntilMarginAvailable("wait request", wait_request_bundles_, /*retain_min=*/1, - /*quota_margin=*/0L, &highest_wait_req_id_being_pulled_, + DeactivateUntilMarginAvailable("task args request", + task_argument_bundles_, + /*retain_min=*/1, + /*quota_margin=*/0L, + &highest_task_req_id_being_pulled_, + &object_ids_to_cancel); + DeactivateUntilMarginAvailable("wait request", + wait_request_bundles_, + /*retain_min=*/1, + /*quota_margin=*/0L, + &highest_wait_req_id_being_pulled_, &object_ids_to_cancel); // Call the cancellation callbacks outside of the lock. @@ -354,8 +378,8 @@ std::vector PullManager::CancelPull(uint64_t request_id) { // If the pull request was being actively pulled, deactivate it now. if (bundle_it->first <= *highest_req_id_being_pulled) { std::unordered_set object_ids_to_cancel; - DeactivatePullBundleRequest(*request_queue, bundle_it, highest_req_id_being_pulled, - &object_ids_to_cancel); + DeactivatePullBundleRequest( + *request_queue, bundle_it, highest_req_id_being_pulled, &object_ids_to_cancel); for (const auto &obj_id : object_ids_to_cancel) { // Call the cancellation callback outside of the lock. cancel_pull_request_(obj_id); @@ -389,7 +413,8 @@ std::vector PullManager::CancelPull(uint64_t request_id) { void PullManager::OnLocationChange(const ObjectID &object_id, const std::unordered_set &client_ids, const std::string &spilled_url, - const NodeID &spilled_node_id, bool pending_creation, + const NodeID &spilled_node_id, + bool pending_creation, size_t object_size) { // Exit if the Pull request has already been fulfilled or canceled. 
auto it = object_pull_requests_.find(object_id); @@ -483,13 +508,13 @@ void PullManager::TryToMakeObjectLocal(const ObjectID &object_id) { if (!direct_restore_url.empty()) { // Select an url from the object directory update UpdateRetryTimer(request, object_id); - restore_spilled_object_(object_id, direct_restore_url, - [object_id](const ray::Status &status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Object restore for " << object_id - << " failed, will retry later: " << status; - } - }); + restore_spilled_object_( + object_id, direct_restore_url, [object_id](const ray::Status &status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Object restore for " << object_id + << " failed, will retry later: " << status; + } + }); return; } diff --git a/src/ray/object_manager/pull_manager.h b/src/ray/object_manager/pull_manager.h index 7703aa84a..d16f8bf14 100644 --- a/src/ray/object_manager/pull_manager.h +++ b/src/ray/object_manager/pull_manager.h @@ -60,12 +60,14 @@ class PullManager { /// \param restore_spilled_object A callback which should /// retrieve an spilled object from the external store. PullManager( - NodeID &self_node_id, const std::function object_is_local, + NodeID &self_node_id, + const std::function object_is_local, const std::function send_pull_request, const std::function cancel_pull_request, const std::function fail_pull_request, const RestoreSpilledObjectCallback restore_spilled_object, - const std::function get_time_seconds, int pull_timeout_ms, + const std::function get_time_seconds, + int pull_timeout_ms, int64_t num_bytes_available, std::function(const ObjectID &object_id)> pin_object, std::function get_locally_spilled_object_url); @@ -111,8 +113,10 @@ class PullManager { /// objects we can safely pull. void OnLocationChange(const ObjectID &object_id, const std::unordered_set &client_ids, - const std::string &spilled_url, const NodeID &spilled_node_id, - bool pending_creation, size_t object_size); + const std::string &spilled_url, + const NodeID &spilled_node_id, + bool pending_creation, + size_t object_size); /// Cancel an existing pull request. /// @@ -272,8 +276,10 @@ class PullManager { /// bundles (in any queue) below this threshold. /// \param quota_margin Keep deactivating bundles until this amount of quota margin /// becomes available. - void DeactivateUntilMarginAvailable(const std::string &debug_name, Queue &bundles, - int retain_min, int64_t quota_margin, + void DeactivateUntilMarginAvailable(const std::string &debug_name, + Queue &bundles, + int retain_min, + int64_t quota_margin, uint64_t *highest_id_for_bundle, std::unordered_set *objects_to_cancel); diff --git a/src/ray/object_manager/push_manager.cc b/src/ray/object_manager/push_manager.cc index 93c47ecaf..4234440da 100644 --- a/src/ray/object_manager/push_manager.cc +++ b/src/ray/object_manager/push_manager.cc @@ -20,7 +20,8 @@ namespace ray { -void PushManager::StartPush(const NodeID &dest_id, const ObjectID &obj_id, +void PushManager::StartPush(const NodeID &dest_id, + const ObjectID &obj_id, int64_t num_chunks, std::function send_chunk_fn) { auto push_id = std::make_pair(dest_id, obj_id); diff --git a/src/ray/object_manager/push_manager.h b/src/ray/object_manager/push_manager.h index 77c263758..609c0ab92 100644 --- a/src/ray/object_manager/push_manager.h +++ b/src/ray/object_manager/push_manager.h @@ -47,7 +47,9 @@ class PushManager { /// \param send_chunk_fn This function will be called with args 0...{num_chunks-1}. 
/// The caller promises to call PushManager::OnChunkComplete() /// once a call to send_chunk_fn finishes. - void StartPush(const NodeID &dest_id, const ObjectID &obj_id, int64_t num_chunks, + void StartPush(const NodeID &dest_id, + const ObjectID &obj_id, + int64_t num_chunks, std::function send_chunk_fn); /// Called every time a chunk completes to trigger additional sends. diff --git a/src/ray/object_manager/spilled_object_reader.cc b/src/ray/object_manager/spilled_object_reader.cc index 662974708..6bf66e2b2 100644 --- a/src/ray/object_manager/spilled_object_reader.cc +++ b/src/ray/object_manager/spilled_object_reader.cc @@ -30,8 +30,8 @@ SpilledObjectReader::CreateSpilledObjectReader(const std::string &object_url) { uint64_t object_offset = 0; uint64_t object_size = 0; - if (!SpilledObjectReader::ParseObjectURL(object_url, file_path, object_offset, - object_size)) { + if (!SpilledObjectReader::ParseObjectURL( + object_url, file_path, object_offset, object_size)) { RAY_LOG(WARNING) << "Failed to parse spilled object url: " << object_url; return absl::optional(); } @@ -43,16 +43,25 @@ SpilledObjectReader::CreateSpilledObjectReader(const std::string &object_url) { rpc::Address owner_address; std::ifstream is(file_path, std::ios::binary); - if (!is || !SpilledObjectReader::ParseObjectHeader(is, object_offset, data_offset, - data_size, metadata_offset, - metadata_size, owner_address)) { + if (!is || !SpilledObjectReader::ParseObjectHeader(is, + object_offset, + data_offset, + data_size, + metadata_offset, + metadata_size, + owner_address)) { RAY_LOG(WARNING) << "Failed to parse object header for spilled object " << object_url; return absl::optional(); } return absl::optional( - SpilledObjectReader(std::move(file_path), object_size, data_offset, data_size, - metadata_offset, metadata_size, std::move(owner_address))); + SpilledObjectReader(std::move(file_path), + object_size, + data_offset, + data_size, + metadata_offset, + metadata_size, + std::move(owner_address))); } uint64_t SpilledObjectReader::GetDataSize() const { return data_size_; } @@ -63,9 +72,12 @@ const rpc::Address &SpilledObjectReader::GetOwnerAddress() const { return owner_address_; } -SpilledObjectReader::SpilledObjectReader(std::string file_path, uint64_t object_size, - uint64_t data_offset, uint64_t data_size, - uint64_t metadata_offset, uint64_t metadata_size, +SpilledObjectReader::SpilledObjectReader(std::string file_path, + uint64_t object_size, + uint64_t data_offset, + uint64_t data_size, + uint64_t metadata_offset, + uint64_t metadata_size, rpc::Address owner_address) : file_path_(std::move(file_path)), object_size_(object_size), @@ -105,8 +117,10 @@ SpilledObjectReader::SpilledObjectReader(std::string file_path, uint64_t object_ } /* static */ -bool SpilledObjectReader::ParseObjectHeader(std::istream &is, uint64_t object_offset, - uint64_t &data_offset, uint64_t &data_size, +bool SpilledObjectReader::ParseObjectHeader(std::istream &is, + uint64_t object_offset, + uint64_t &data_offset, + uint64_t &data_size, uint64_t &metadata_offset, uint64_t &metadata_size, rpc::Address &owner_address) { @@ -152,13 +166,15 @@ uint64_t SpilledObjectReader::ToUINT64(const std::string &s) { return result; } -bool SpilledObjectReader::ReadFromDataSection(uint64_t offset, uint64_t size, +bool SpilledObjectReader::ReadFromDataSection(uint64_t offset, + uint64_t size, char *output) const { std::ifstream is(file_path_, std::ios::binary); return is.seekg(data_offset_ + offset) && is.read(output, size); } -bool 
SpilledObjectReader::ReadFromMetadataSection(uint64_t offset, uint64_t size, +bool SpilledObjectReader::ReadFromMetadataSection(uint64_t offset, + uint64_t size, char *output) const { std::ifstream is(file_path_, std::ios::binary); return is.seekg(metadata_offset_ + offset) && is.read(output, size); diff --git a/src/ray/object_manager/spilled_object_reader.h b/src/ray/object_manager/spilled_object_reader.h index 349ba749c..1b5b3441c 100644 --- a/src/ray/object_manager/spilled_object_reader.h +++ b/src/ray/object_manager/spilled_object_reader.h @@ -41,13 +41,18 @@ class SpilledObjectReader : public IObjectReader { const rpc::Address &GetOwnerAddress() const override; bool ReadFromDataSection(uint64_t offset, uint64_t size, char *output) const override; - bool ReadFromMetadataSection(uint64_t offset, uint64_t size, + bool ReadFromMetadataSection(uint64_t offset, + uint64_t size, char *output) const override; private: - SpilledObjectReader(std::string file_path, uint64_t total_size, uint64_t data_offset, - uint64_t data_size, uint64_t metadata_offset, - uint64_t metadata_size, rpc::Address owner_address); + SpilledObjectReader(std::string file_path, + uint64_t total_size, + uint64_t data_offset, + uint64_t data_size, + uint64_t metadata_offset, + uint64_t metadata_size, + rpc::Address owner_address); /// Parse the object url in the form of {path}?offset={offset}&size={size}. /// Return false if parsing failed. @@ -57,8 +62,10 @@ class SpilledObjectReader : public IObjectReader { /// \param[out] object_offset offset of the object stored in the file.. /// \param[out] total_size object size in the file. /// \return bool. - static bool ParseObjectURL(const std::string &object_url, std::string &file_path, - uint64_t &object_offset, uint64_t &total_size); + static bool ParseObjectURL(const std::string &object_url, + std::string &file_path, + uint64_t &object_offset, + uint64_t &total_size); /// Read the istream, parse the object header according to the following format. /// Return false if the input stream is deleted or corrupted. @@ -80,9 +87,12 @@ class SpilledObjectReader : public IObjectReader { /// \param[out] metadata_size size of the metadata payload. /// \param[out] owner_address owner address. /// \return bool. 
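// A self-contained sketch of parsing the spilled-object URL format documented
// above, "{path}?offset={offset}&size={size}". Illustrative only; the real
// ParseObjectURL also rejects values that do not fit in a signed 64-bit int.
#include <cstdint>
#include <stdexcept>
#include <string>

inline bool ParseObjectUrlSketch(const std::string &object_url,
                                 std::string &file_path,
                                 uint64_t &offset,
                                 uint64_t &size) {
  const auto query = object_url.rfind("?offset=");
  if (query == std::string::npos) return false;
  const auto size_pos = object_url.find("&size=", query);
  if (size_pos == std::string::npos) return false;
  file_path = object_url.substr(0, query);
  try {
    // "?offset=" is 8 characters, "&size=" is 6.
    offset = std::stoull(object_url.substr(query + 8, size_pos - (query + 8)));
    size = std::stoull(object_url.substr(size_pos + 6));
  } catch (const std::exception &) {
    return false;  // non-numeric or out-of-range offset/size
  }
  return true;
}

// e.g. ParseObjectUrlSketch("/tmp/file.txt?offset=123&size=456", path, off, sz)
// yields path == "/tmp/file.txt", off == 123, sz == 456, matching the
// expectations in the ParseObjectURL tests later in this patch.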
- static bool ParseObjectHeader(std::istream &is, uint64_t object_offset, - uint64_t &data_offset, uint64_t &data_size, - uint64_t &metadata_offset, uint64_t &metadata_size, + static bool ParseObjectHeader(std::istream &is, + uint64_t object_offset, + uint64_t &data_offset, + uint64_t &data_size, + uint64_t &metadata_offset, + uint64_t &metadata_size, rpc::Address &owner_address); /// Read 8 bytes from inputstream and deserialize it as a little-endian diff --git a/src/ray/object_manager/test/get_request_queue_test.cc b/src/ray/object_manager/test/get_request_queue_test.cc index 8445dd715..308344949 100644 --- a/src/ray/object_manager/test/get_request_queue_test.cc +++ b/src/ray/object_manager/test/get_request_queue_test.cc @@ -35,7 +35,8 @@ class MockObjectLifecycleManager : public IObjectLifecycleManager { MOCK_METHOD3(CreateObject, std::pair( const ray::ObjectInfo &object_info, - plasma::flatbuf::ObjectSource source, bool fallback_allocator)); + plasma::flatbuf::ObjectSource source, + bool fallback_allocator)); MOCK_CONST_METHOD1(GetObject, const LocalObject *(const ObjectID &object_id)); MOCK_METHOD1(SealObject, const LocalObject *(const ObjectID &object_id)); MOCK_METHOD1(AbortObject, flatbuf::PlasmaError(const ObjectID &object_id)); @@ -103,7 +104,8 @@ TEST_F(GetRequestQueueTest, TestObjectSealed) { bool satisfied = false; MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) {}, [&](const std::shared_ptr &get_req) { satisfied = true; }); auto client = std::make_shared(); @@ -123,7 +125,8 @@ TEST_F(GetRequestQueueTest, TestObjectTimeout) { std::promise promise; MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) {}, [&](const std::shared_ptr &get_req) { promise.set_value(true); }); auto client = std::make_shared(); @@ -144,7 +147,8 @@ TEST_F(GetRequestQueueTest, TestObjectNotSealed) { std::promise promise; MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) {}, [&](const std::shared_ptr &get_req) { promise.set_value(true); }); auto client = std::make_shared(); @@ -167,7 +171,8 @@ TEST_F(GetRequestQueueTest, TestMultipleObjects) { std::promise promise1, promise2, promise3; MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) { if (object_id == object_id1) { promise1.set_value(true); @@ -203,7 +208,8 @@ TEST_F(GetRequestQueueTest, TestMultipleObjects) { TEST_F(GetRequestQueueTest, TestDuplicateObjects) { MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) {}, [&](const std::shared_ptr &get_req) {}); auto client = std::make_shared(); @@ -227,7 +233,8 @@ TEST_F(GetRequestQueueTest, TestDuplicateObjects) { TEST_F(GetRequestQueueTest, TestRemoveAll) { MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - 
io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) {}, [&](const std::shared_ptr &get_req) {}); auto client = std::make_shared(); @@ -255,7 +262,8 @@ TEST_F(GetRequestQueueTest, TestRemoveAll) { TEST_F(GetRequestQueueTest, TestRemoveTwice) { MockObjectLifecycleManager object_lifecycle_manager; GetRequestQueue get_request_queue( - io_context_, object_lifecycle_manager, + io_context_, + object_lifecycle_manager, [&](const ObjectID &object_id, const auto &request) {}, [&](const std::shared_ptr &get_req) {}); auto client = std::make_shared(); diff --git a/src/ray/object_manager/test/object_buffer_pool_test.cc b/src/ray/object_manager/test/object_buffer_pool_test.cc index 0dac3bf1d..1ae4602f0 100644 --- a/src/ray/object_manager/test/object_buffer_pool_test.cc +++ b/src/ray/object_manager/test/object_buffer_pool_test.cc @@ -42,7 +42,8 @@ class MockPlasmaClient : public plasma::PlasmaClientInterface { MOCK_METHOD0(Disconnect, ray::Status()); MOCK_METHOD4(Get, - ray::Status(const std::vector &object_ids, int64_t timeout_ms, + ray::Status(const std::vector &object_ids, + int64_t timeout_ms, std::vector *object_buffers, bool is_from_worker)); @@ -52,8 +53,10 @@ class MockPlasmaClient : public plasma::PlasmaClientInterface { ray::Status CreateAndSpillIfNeeded(const ObjectID &object_id, const ray::rpc::Address &owner_address, - int64_t data_size, const uint8_t *metadata, - int64_t metadata_size, std::shared_ptr *data, + int64_t data_size, + const uint8_t *metadata, + int64_t metadata_size, + std::shared_ptr *data, plasma::flatbuf::ObjectSource source, int device_num) { *data = std::make_shared(data_size); diff --git a/src/ray/object_manager/test/ownership_based_object_directory_test.cc b/src/ray/object_manager/test/ownership_based_object_directory_test.cc index 561ec9309..dff9a70b5 100644 --- a/src/ray/object_manager/test/ownership_based_object_directory_test.cc +++ b/src/ray/object_manager/test/ownership_based_object_directory_test.cc @@ -61,7 +61,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { return true; } - void AssertObjectIDState(const WorkerID &worker_id, const ObjectID &object_id, + void AssertObjectIDState(const WorkerID &worker_id, + const ObjectID &object_id, rpc::ObjectLocationState state) { auto it = buffered_object_locations_.find(worker_id); RAY_CHECK(it != buffered_object_locations_.end()) @@ -113,7 +114,10 @@ class OwnershipBasedObjectDirectoryTest : public ::testing::Test { subscriber_(std::make_shared()), owner_client(std::make_shared()), client_pool([&](const rpc::Address &addr) { return owner_client; }), - obod_(io_service_, gcs_client_mock_, subscriber_.get(), &client_pool, + obod_(io_service_, + gcs_client_mock_, + subscriber_.get(), + &client_pool, /*max_object_report_batch_size=*/20, [this](const ObjectID &object_id, const rpc::ErrorType &error_type) { MarkAsFailed(object_id, error_type); @@ -141,7 +145,8 @@ class OwnershipBasedObjectDirectoryTest : public ::testing::Test { return info; } - void AssertObjectIDState(const WorkerID &worker_id, const ObjectID &object_id, + void AssertObjectIDState(const WorkerID &worker_id, + const ObjectID &object_id, rpc::ObjectLocationState state) { owner_client->AssertObjectIDState(worker_id, object_id, state); } @@ -166,10 +171,11 @@ class OwnershipBasedObjectDirectoryTest : public ::testing::Test { } void HandleMessage(const rpc::WorkerObjectLocationsPubMessage &location_info, - const ObjectID &object_id, bool location_lookup_failed = 
false) { + const ObjectID &object_id, + bool location_lookup_failed = false) { // Mock for receiving a message from the pubsub layer. - obod_.ObjectLocationSubscriptionCallback(location_info, object_id, - location_lookup_failed); + obod_.ObjectLocationSubscriptionCallback( + location_info, object_id, location_lookup_failed); } int64_t max_batch_size = 20; @@ -191,9 +197,10 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestLocationUpdateBatchBasic) { { RAY_LOG(INFO) << "Object added basic."; auto object_info_added = CreateNewObjectInfo(owner_id); - obod_.ReportObjectAdded(object_info_added.object_id, current_node_id, - object_info_added); - AssertObjectIDState(object_info_added.owner_worker_id, object_info_added.object_id, + obod_.ReportObjectAdded( + object_info_added.object_id, current_node_id, object_info_added); + AssertObjectIDState(object_info_added.owner_worker_id, + object_info_added.object_id, rpc::ObjectLocationState::ADDED); ASSERT_TRUE(owner_client->ReplyUpdateObjectLocationBatch()); ASSERT_EQ(NumBatchRequestSent(), 1); @@ -204,10 +211,11 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestLocationUpdateBatchBasic) { { RAY_LOG(INFO) << "Object removed basic."; auto object_info_removed = CreateNewObjectInfo(owner_id); - obod_.ReportObjectRemoved(object_info_removed.object_id, current_node_id, - object_info_removed); + obod_.ReportObjectRemoved( + object_info_removed.object_id, current_node_id, object_info_removed); AssertObjectIDState(object_info_removed.owner_worker_id, - object_info_removed.object_id, rpc::ObjectLocationState::REMOVED); + object_info_removed.object_id, + rpc::ObjectLocationState::REMOVED); ASSERT_TRUE(owner_client->ReplyUpdateObjectLocationBatch()); ASSERT_EQ(NumBatchRequestSent(), 2); ASSERT_EQ(NumBatchReplied(), 2); @@ -227,7 +235,8 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestLocationUpdateBufferedUpdate) { ASSERT_EQ(NumBatchRequestSent(), 2); // For the same object ID, it should report the latest result (which is REMOVED). - AssertObjectIDState(object_info.owner_worker_id, object_info.object_id, + AssertObjectIDState(object_info.owner_worker_id, + object_info.object_id, rpc::ObjectLocationState::REMOVED); ASSERT_TRUE(owner_client->ReplyUpdateObjectLocationBatch()); @@ -253,9 +262,11 @@ TEST_F(OwnershipBasedObjectDirectoryTest, ASSERT_TRUE(owner_client->ReplyUpdateObjectLocationBatch()); ASSERT_EQ(NumBatchReplied(), 2); // For the same object ID, it should report the latest result (which is REMOVED). - AssertObjectIDState(object_info.owner_worker_id, object_info.object_id, + AssertObjectIDState(object_info.owner_worker_id, + object_info.object_id, rpc::ObjectLocationState::REMOVED); - AssertObjectIDState(object_info_2.owner_worker_id, object_info_2.object_id, + AssertObjectIDState(object_info_2.owner_worker_id, + object_info_2.object_id, rpc::ObjectLocationState::ADDED); AssertNoLeak(); } @@ -284,9 +295,11 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestLocationUpdateBufferedMultipleOwne ASSERT_EQ(NumBatchRequestSent(), 4); ASSERT_EQ(NumBatchReplied(), 2); // For the same object ID, it should report the latest result (which is REMOVED). - AssertObjectIDState(object_info.owner_worker_id, object_info.object_id, + AssertObjectIDState(object_info.owner_worker_id, + object_info.object_id, rpc::ObjectLocationState::REMOVED); - AssertObjectIDState(object_info_2.owner_worker_id, object_info_2.object_id, + AssertObjectIDState(object_info_2.owner_worker_id, + object_info_2.object_id, rpc::ObjectLocationState::ADDED); // Clean up reply and check assert. 
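// A small illustrative model (hypothetical shape, not the object directory's
// real code) of the batching behavior these tests pin down: while an
// UpdateObjectLocationBatch RPC is in flight for an owner, further updates are
// buffered, and repeated updates for the same object collapse to the latest
// state (last write wins), which is why an ADD followed by a REMOVE is
// reported as REMOVED.
#include <string>
#include <unordered_map>

enum class LocationStateSketch { ADDED, REMOVED };

struct OwnerUpdateBufferSketch {
  std::unordered_map<std::string, LocationStateSketch> pending;  // object id -> latest state
  bool request_in_flight = false;

  void Report(const std::string &object_id, LocationStateSketch state) {
    // Overwrite any buffered state; a real implementation would flush up to
    // max_batch_size entries once no request is in flight.
    pending[object_id] = state;
  }
};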
@@ -314,7 +327,8 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestLocationUpdateOneInFlightRequest) ASSERT_TRUE(owner_client->ReplyUpdateObjectLocationBatch()); ASSERT_EQ(NumBatchRequestSent(), 2); - AssertObjectIDState(object_info.owner_worker_id, object_info.object_id, + AssertObjectIDState(object_info.owner_worker_id, + object_info.object_id, rpc::ObjectLocationState::REMOVED); // After it is replied, if there's no more entry in the buffer, it doesn't send a new @@ -353,7 +367,8 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestLocationUpdateMaxBatchSize) { // Check if object id states are updated properly. for (const auto &object_info : object_infos) { - AssertObjectIDState(object_info.owner_worker_id, object_info.object_id, + AssertObjectIDState(object_info.owner_worker_id, + object_info.object_id, rpc::ObjectLocationState::REMOVED); } AssertNoLeak(); @@ -394,11 +409,15 @@ TEST_F(OwnershipBasedObjectDirectoryTest, TestNotifyOnUpdate) { EXPECT_CALL(*subscriber_, Subscribe(_, _, _, _, _, _, _)).WillOnce(Return(true)); ASSERT_TRUE( obod_ - .SubscribeObjectLocations( - callback_id, obj_id, rpc::Address(), - [&](const ObjectID &object_id, const std::unordered_set &client_ids, - const std::string &spilled_url, const NodeID &spilled_node_id, - bool pending_creation, size_t object_size) { num_callbacks++; }) + .SubscribeObjectLocations(callback_id, + obj_id, + rpc::Address(), + [&](const ObjectID &object_id, + const std::unordered_set &client_ids, + const std::string &spilled_url, + const NodeID &spilled_node_id, + bool pending_creation, + size_t object_size) { num_callbacks++; }) .ok()); ASSERT_EQ(num_callbacks, 0); diff --git a/src/ray/object_manager/test/pull_manager_test.cc b/src/ray/object_manager/test/pull_manager_test.cc index ec8e85301..e5a21504b 100644 --- a/src/ray/object_manager/test/pull_manager_test.cc +++ b/src/ray/object_manager/test/pull_manager_test.cc @@ -32,18 +32,22 @@ class PullManagerTestWithCapacity { num_restore_spilled_object_calls_(0), fake_time_(0), pull_manager_( - self_node_id_, [this](const ObjectID &object_id) { return object_is_local_; }, + self_node_id_, + [this](const ObjectID &object_id) { return object_is_local_; }, [this](const ObjectID &object_id, const NodeID &node_id) { num_send_pull_request_calls_++; }, [this](const ObjectID &object_id) { num_abort_calls_[object_id]++; }, [this](const ObjectID &object_id) { timed_out_objects_.insert(object_id); }, - [this](const ObjectID &, const std::string &, + [this](const ObjectID &, + const std::string &, std::function callback) { num_restore_spilled_object_calls_++; restore_object_callback_ = callback; }, - [this]() { return fake_time_; }, 10000, num_available_bytes, + [this]() { return fake_time_; }, + 10000, + num_available_bytes, [this](const ObjectID &object_id) { return PinReturn(); }, [this](const ObjectID &object_id) { return GetLocalSpilledObjectURL(object_id); @@ -200,8 +204,8 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectRemote) { NodeID node_that_object_spilled = NodeID::FromRandom(); fake_time_ += 10.; ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", - node_that_object_spilled, false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", node_that_object_spilled, false, 0); // We request a remote pull to restore the object. ASSERT_EQ(num_send_pull_request_calls_, 1); @@ -209,8 +213,8 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectRemote) { // No retry yet. 
ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", - node_that_object_spilled, false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", node_that_object_spilled, false, 0); ASSERT_EQ(num_send_pull_request_calls_, 1); ASSERT_EQ(num_restore_spilled_object_calls_, 0); @@ -218,16 +222,16 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectRemote) { client_ids.insert(node_that_object_spilled); fake_time_ += 10.; ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", - node_that_object_spilled, false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", node_that_object_spilled, false, 0); ASSERT_EQ(num_send_pull_request_calls_, 2); ASSERT_EQ(num_restore_spilled_object_calls_, 0); // Don't restore an object if it's local. object_is_local_ = true; ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", - NodeID::FromRandom(), false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", NodeID::FromRandom(), false, 0); ASSERT_EQ(num_send_pull_request_calls_, 2); ASSERT_EQ(num_restore_spilled_object_calls_, 0); @@ -262,8 +266,8 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectLocal) { fake_time_ += 10.; ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", self_node_id_, - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", self_node_id_, false, 0); // We request a local restore. ASSERT_EQ(num_send_pull_request_calls_, 0); @@ -271,16 +275,16 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectLocal) { // No retry yet. ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", self_node_id_, - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", self_node_id_, false, 0); ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 1); // The call can be retried after a delay. fake_time_ += 10.; ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", self_node_id_, - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", self_node_id_, false, 0); ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 2); @@ -327,8 +331,8 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectOnLocalStorage) { // The call can be retried after a delay, and the url in the remote object directory is // updated now. fake_time_ += 10.; - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", self_node_id_, - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", self_node_id_, false, 0); ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 2); @@ -367,16 +371,16 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectOnExternalStorage) { ObjectSpilled(obj1, ""); // If objects are spilled to external storages, the node id should be Nil(). // So this shouldn't invoke restoration. - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", self_node_id_, - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", self_node_id_, false, 0); // We request a local restore. 
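// The branch these restore tests exercise, as a simplified self-contained
// sketch (string IDs and plain callbacks are stand-ins for the real NodeID and
// callback types): objects spilled by this node, or spilled to external
// storage (nil spill node), are restored directly from the spill URL, while
// objects spilled by a remote node are fetched with a pull request instead.
#include <functional>
#include <string>

inline void TryToMakeLocalSketch(const std::string &spilled_node_id,
                                 const std::string &self_node_id,
                                 const std::function<void()> &restore_from_url,
                                 const std::function<void()> &send_pull_request) {
  const bool spilled_externally = spilled_node_id.empty();  // stands in for NodeID::Nil()
  if (spilled_externally || spilled_node_id == self_node_id) {
    restore_from_url();   // local or external storage: restore call count goes up
  } else {
    send_pull_request();  // remote spill: pull request count goes up instead
  }
}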
ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 0); // Now Nil ID is properly updated. - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", NodeID::Nil(), - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", NodeID::Nil(), false, 0); // We request a local restore. ASSERT_EQ(num_send_pull_request_calls_, 0); @@ -384,8 +388,8 @@ TEST_P(PullManagerTest, TestRestoreSpilledObjectOnExternalStorage) { // The call can be retried after a delay. fake_time_ += 10.; - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", NodeID::Nil(), - false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", NodeID::Nil(), false, 0); ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 2); @@ -422,8 +426,8 @@ TEST_P(PullManagerTest, TestLoadBalancingRestorationRequest) { client_ids.insert(copy_node1); client_ids.insert(copy_node2); ObjectSpilled(obj1, "remote_url/foo/bar"); - pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar", - remote_node_that_spilled_object, false, 0); + pull_manager_.OnLocationChange( + obj1, client_ids, "remote_url/foo/bar", remote_node_that_spilled_object, false, 0); ASSERT_EQ(num_send_pull_request_calls_, 1); // Make sure the restore request wasn't sent since there are nodes that have a copied @@ -779,8 +783,8 @@ TEST_P(PullManagerWithAdmissionControlTest, TestBasic) { client_ids.insert(NodeID::FromRandom()); for (size_t i = 0; i < oids.size(); i++) { ASSERT_FALSE(pull_manager_.IsObjectActive(oids[i])); - pull_manager_.OnLocationChange(oids[i], client_ids, "", NodeID::Nil(), false, - object_size); + pull_manager_.OnLocationChange( + oids[i], client_ids, "", NodeID::Nil(), false, object_size); } ASSERT_EQ(num_send_pull_request_calls_, oids.size()); ASSERT_EQ(num_restore_spilled_object_calls_, 0); @@ -847,8 +851,8 @@ TEST_P(PullManagerWithAdmissionControlTest, TestQueue) { client_ids.insert(NodeID::FromRandom()); for (auto &oids : bundles) { for (size_t i = 0; i < oids.size(); i++) { - pull_manager_.OnLocationChange(oids[i], client_ids, "", NodeID::Nil(), false, - object_size); + pull_manager_.OnLocationChange( + oids[i], client_ids, "", NodeID::Nil(), false, object_size); } } @@ -895,7 +899,9 @@ TEST_P(PullManagerWithAdmissionControlTest, TestCancel) { /// Test admission control while requests are cancelled out-of-order. When an /// active request is cancelled, we should activate another request in the /// queue, if there is one that satisfies the reported capacity. - auto test_cancel = [&](std::vector object_sizes, int capacity, size_t cancel_idx, + auto test_cancel = [&](std::vector object_sizes, + int capacity, + size_t cancel_idx, int num_active_requests_expected_before, int num_active_requests_expected_after) { pull_manager_.UpdatePullsBasedOnAvailableMemory(capacity); @@ -908,24 +914,28 @@ TEST_P(PullManagerWithAdmissionControlTest, TestCancel) { req_ids.push_back(req_id); } for (size_t i = 0; i < object_sizes.size(); i++) { - pull_manager_.OnLocationChange(oids[i], {}, "", NodeID::Nil(), false, - object_sizes[i]); + pull_manager_.OnLocationChange( + oids[i], {}, "", NodeID::Nil(), false, object_sizes[i]); } AssertNumActiveRequestsEquals(num_active_requests_expected_before); pull_manager_.CancelPull(req_ids[cancel_idx]); AssertNumActiveRequestsEquals(num_active_requests_expected_after); // Request is really canceled. 
- pull_manager_.OnLocationChange(oids[cancel_idx], {NodeID::FromRandom()}, "", - NodeID::Nil(), false, object_sizes[cancel_idx]); + pull_manager_.OnLocationChange(oids[cancel_idx], + {NodeID::FromRandom()}, + "", + NodeID::Nil(), + false, + object_sizes[cancel_idx]); ASSERT_EQ(num_send_pull_request_calls_, 0); // The expected number of requests at the head of the queue are pulled. int num_active = 0; for (size_t i = 0; i < refs.size() && num_active < num_active_requests_expected_after; i++) { - pull_manager_.OnLocationChange(oids[i], {NodeID::FromRandom()}, "", NodeID::Nil(), - false, object_sizes[i]); + pull_manager_.OnLocationChange( + oids[i], {NodeID::FromRandom()}, "", NodeID::Nil(), false, object_sizes[i]); if (i != cancel_idx) { num_active++; } @@ -981,8 +991,8 @@ TEST_F(PullManagerWithAdmissionControlTest, TestPrioritizeWorkerRequests) { std::unordered_set client_ids; client_ids.insert(NodeID::FromRandom()); for (auto &oid : task_oids) { - pull_manager_.OnLocationChange(oid, client_ids, "", NodeID::Nil(), false, - object_size); + pull_manager_.OnLocationChange( + oid, client_ids, "", NodeID::Nil(), false, object_size); } // Two requests can be pulled at a time. @@ -996,8 +1006,8 @@ TEST_F(PullManagerWithAdmissionControlTest, TestPrioritizeWorkerRequests) { auto wait_req_id = pull_manager_.Pull(refs, BundlePriority::WAIT_REQUEST, &objects_to_locate); wait_oids.push_back(ObjectRefsToIds(refs)[0]); - pull_manager_.OnLocationChange(wait_oids[0], client_ids, "", NodeID::Nil(), false, - object_size); + pull_manager_.OnLocationChange( + wait_oids[0], client_ids, "", NodeID::Nil(), false, object_size); AssertNumActiveRequestsEquals(2); ASSERT_TRUE(pull_manager_.IsObjectActive(wait_oids[0])); ASSERT_TRUE(pull_manager_.IsObjectActive(task_oids[0])); @@ -1017,8 +1027,8 @@ TEST_F(PullManagerWithAdmissionControlTest, TestPrioritizeWorkerRequests) { // Worker request takes priority over the wait and task requests once its size is // available. 
for (auto &oid : get_oids) { - pull_manager_.OnLocationChange(oid, client_ids, "", NodeID::Nil(), false, - object_size); + pull_manager_.OnLocationChange( + oid, client_ids, "", NodeID::Nil(), false, object_size); } AssertNumActiveRequestsEquals(2); ASSERT_TRUE(pull_manager_.IsObjectActive(get_oids[0])); @@ -1038,8 +1048,8 @@ TEST_F(PullManagerWithAdmissionControlTest, TestPrioritizeWorkerRequests) { ASSERT_FALSE(pull_manager_.IsObjectActive(task_oids[0])); ASSERT_FALSE(pull_manager_.IsObjectActive(task_oids[1])); for (auto &oid : get_oids) { - pull_manager_.OnLocationChange(oid, client_ids, "", NodeID::Nil(), false, - object_size); + pull_manager_.OnLocationChange( + oid, client_ids, "", NodeID::Nil(), false, object_size); } AssertNumActiveRequestsEquals(2); ASSERT_TRUE(pull_manager_.IsObjectActive(get_oids[0])); @@ -1179,10 +1189,12 @@ TEST_P(PullManagerTest, TestTimeOutAfterFailedPull) { AssertNoLeaks(); } -INSTANTIATE_TEST_SUITE_P(WorkerOrTaskRequests, PullManagerTest, +INSTANTIATE_TEST_SUITE_P(WorkerOrTaskRequests, + PullManagerTest, testing::Values(true, false)); -INSTANTIATE_TEST_SUITE_P(WorkerOrTaskRequests, PullManagerWithAdmissionControlTest, +INSTANTIATE_TEST_SUITE_P(WorkerOrTaskRequests, + PullManagerWithAdmissionControlTest, testing::Values(true, false)); } // namespace ray diff --git a/src/ray/object_manager/test/spilled_object_test.cc b/src/ray/object_manager/test/spilled_object_test.cc index ede6f62b9..f9f45e1b5 100644 --- a/src/ray/object_manager/test/spilled_object_test.cc +++ b/src/ray/object_manager/test/spilled_object_test.cc @@ -12,47 +12,49 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ray/object_manager/chunk_object_reader.h" -#include "ray/object_manager/memory_object_reader.h" -#include "ray/object_manager/spilled_object_reader.h" - #include #include #include "absl/strings/str_format.h" #include "gtest/gtest.h" #include "ray/common/test_util.h" +#include "ray/object_manager/chunk_object_reader.h" +#include "ray/object_manager/memory_object_reader.h" +#include "ray/object_manager/spilled_object_reader.h" #include "ray/util/filesystem.h" namespace ray { TEST(SpilledObjectReaderTest, ParseObjectURL) { - auto assert_parse_success = - [](const std::string &object_url, const std::string &expected_file_path, - uint64_t expected_object_offset, uint64_t expected_object_size) { - std::string actual_file_path; - uint64_t actual_offset = 0; - uint64_t actual_size = 0; - ASSERT_TRUE(SpilledObjectReader::ParseObjectURL(object_url, actual_file_path, - actual_offset, actual_size)); - ASSERT_EQ(expected_file_path, actual_file_path); - ASSERT_EQ(expected_object_offset, actual_offset); - ASSERT_EQ(expected_object_size, actual_size); - }; + auto assert_parse_success = [](const std::string &object_url, + const std::string &expected_file_path, + uint64_t expected_object_offset, + uint64_t expected_object_size) { + std::string actual_file_path; + uint64_t actual_offset = 0; + uint64_t actual_size = 0; + ASSERT_TRUE(SpilledObjectReader::ParseObjectURL( + object_url, actual_file_path, actual_offset, actual_size)); + ASSERT_EQ(expected_file_path, actual_file_path); + ASSERT_EQ(expected_object_offset, actual_offset); + ASSERT_EQ(expected_object_size, actual_size); + }; auto assert_parse_fail = [](const std::string &object_url) { std::string actual_file_path; uint64_t actual_offset = 0; uint64_t actual_size = 0; - ASSERT_FALSE(SpilledObjectReader::ParseObjectURL(object_url, actual_file_path, - actual_offset, 
actual_size)); + ASSERT_FALSE(SpilledObjectReader::ParseObjectURL( + object_url, actual_file_path, actual_offset, actual_size)); }; - assert_parse_success("file://path/to/file?offset=123&size=456", "file://path/to/file", - 123, 456); + assert_parse_success( + "file://path/to/file?offset=123&size=456", "file://path/to/file", 123, 456); assert_parse_success("http://123?offset=123&size=456", "http://123", 123, 456); assert_parse_success("file:///C:/Users/file.txt?offset=123&size=456", - "file:///C:/Users/file.txt", 123, 456); + "file:///C:/Users/file.txt", + 123, + 456); assert_parse_success("/tmp/file.txt?offset=123&size=456", "/tmp/file.txt", 123, 456); assert_parse_success("C:\\file.txt?offset=123&size=456", "C:\\file.txt", 123, 456); assert_parse_success( @@ -61,9 +63,10 @@ TEST(SpilledObjectReaderTest, ParseObjectURL) { "2199437144", "/tmp/ray/session_2021-07-19_09-50-58_115365_119/ray_spillled_objects/" "2f81e7cfcc578f4effffffffffffffffffffffff0200000001000000-multi-1", - 0, 2199437144); - assert_parse_success("/tmp/123?offset=0&size=9223372036854775807", "/tmp/123", 0, - 9223372036854775807); + 0, + 2199437144); + assert_parse_success( + "/tmp/123?offset=0&size=9223372036854775807", "/tmp/123", 0, 9223372036854775807); assert_parse_fail("/tmp/123?offset=-1&size=1"); assert_parse_fail("/tmp/123?offset=0&size=9223372036854775808"); @@ -74,10 +77,12 @@ TEST(SpilledObjectReaderTest, ParseObjectURL) { } TEST(SpilledObjectReaderTest, ToUINT64) { - ASSERT_EQ(0, SpilledObjectReader::ToUINT64( - {'\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'})); - ASSERT_EQ(1, SpilledObjectReader::ToUINT64( - {'\x01', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'})); + ASSERT_EQ(0, + SpilledObjectReader::ToUINT64( + {'\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'})); + ASSERT_EQ(1, + SpilledObjectReader::ToUINT64( + {'\x01', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'})); ASSERT_EQ(std::numeric_limits::max(), SpilledObjectReader::ToUINT64( {'\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff'})); @@ -104,8 +109,10 @@ TEST(SpilledObjectReaderTest, ReadUINT64) { } namespace { -std::string ContructObjectString(uint64_t object_offset, std::string data, - std::string metadata, rpc::Address owner_address) { +std::string ContructObjectString(uint64_t object_offset, + std::string data, + std::string metadata, + rpc::Address owner_address) { std::string result(object_offset, '\0'); std::string address_str; owner_address.SerializeToString(&address_str); @@ -124,8 +131,10 @@ std::string ContructObjectString(uint64_t object_offset, std::string data, } // namespace TEST(SpilledObjectReaderTest, ParseObjectHeader) { - auto assert_parse_success = [](uint64_t object_offset, std::string data, - std::string metadata, std::string raylet_id) { + auto assert_parse_success = [](uint64_t object_offset, + std::string data, + std::string metadata, + std::string raylet_id) { rpc::Address owner_address; owner_address.set_raylet_id(raylet_id); auto str = ContructObjectString(object_offset, data, metadata, owner_address); @@ -135,9 +144,13 @@ TEST(SpilledObjectReaderTest, ParseObjectHeader) { uint64_t actual_metadata_size = 0; rpc::Address actual_owner_address; std::istringstream is(str); - ASSERT_TRUE(SpilledObjectReader::ParseObjectHeader( - is, object_offset, actual_data_offset, actual_data_size, actual_metadata_offset, - actual_metadata_size, actual_owner_address)); + ASSERT_TRUE(SpilledObjectReader::ParseObjectHeader(is, + object_offset, + actual_data_offset, + 
actual_data_size, + actual_metadata_offset, + actual_metadata_size, + actual_owner_address)); std::string address_str; owner_address.SerializeToString(&address_str); ASSERT_EQ(object_offset + 24 + address_str.size(), actual_metadata_offset); @@ -178,9 +191,13 @@ TEST(SpilledObjectReaderTest, ParseObjectHeader) { uint64_t actual_metadata_size = 0; rpc::Address actual_owner_address; std::istringstream is(str); - ASSERT_FALSE(SpilledObjectReader::ParseObjectHeader( - is, object_offset, actual_data_offset, actual_data_size, actual_metadata_offset, - actual_metadata_size, actual_owner_address)); + ASSERT_FALSE(SpilledObjectReader::ParseObjectHeader(is, + object_offset, + actual_data_offset, + actual_data_size, + actual_metadata_offset, + actual_metadata_size, + actual_owner_address)); }; std::string address_str; @@ -193,7 +210,8 @@ TEST(SpilledObjectReaderTest, ParseObjectHeader) { } namespace { -std::string CreateSpilledObjectReaderOnTmp(uint64_t object_offset, std::string data, +std::string CreateSpilledObjectReaderOnTmp(uint64_t object_offset, + std::string data, std::string metadata, rpc::Address owner_address, bool skip_write = false) { @@ -206,11 +224,12 @@ std::string CreateSpilledObjectReaderOnTmp(uint64_t object_offset, std::string d RAY_CHECK(f.write(str.c_str(), str.size())); } f.close(); - return absl::StrFormat("%s?offset=%d&size=%d", tmp_file, object_offset, - str.size() - object_offset); + return absl::StrFormat( + "%s?offset=%d&size=%d", tmp_file, object_offset, str.size() - object_offset); } -MemoryObjectReader CreateMemoryObjectReader(std::string &data, std::string &metadata, +MemoryObjectReader CreateMemoryObjectReader(std::string &data, + std::string &metadata, rpc::Address owner_address) { plasma::ObjectBuffer object_buffer; object_buffer.data = @@ -223,59 +242,66 @@ MemoryObjectReader CreateMemoryObjectReader(std::string &data, std::string &meta } // namespace TEST(ChunkObjectReaderTest, GetNumChunks) { - auto assert_get_num_chunks = [](uint64_t data_size, uint64_t chunk_size, - uint64_t expected_num_chunks) { - rpc::Address owner_address; - owner_address.set_raylet_id("nonsense"); - ChunkObjectReader reader( - std::make_shared(SpilledObjectReader( - "path", 100 /* object_size */, 2 /* data_offset */, data_size /* data_size */, - 4 /* metadata_offset */, 0 /* metadata_size */, owner_address)), - chunk_size /* chunk_size */); + auto assert_get_num_chunks = + [](uint64_t data_size, uint64_t chunk_size, uint64_t expected_num_chunks) { + rpc::Address owner_address; + owner_address.set_raylet_id("nonsense"); + ChunkObjectReader reader(std::make_shared( + SpilledObjectReader("path", + 100 /* object_size */, + 2 /* data_offset */, + data_size /* data_size */, + 4 /* metadata_offset */, + 0 /* metadata_size */, + owner_address)), + chunk_size /* chunk_size */); - ASSERT_EQ(expected_num_chunks, reader.GetNumChunks()); - ASSERT_EQ(expected_num_chunks, reader.GetNumChunks()); - }; + ASSERT_EQ(expected_num_chunks, reader.GetNumChunks()); + ASSERT_EQ(expected_num_chunks, reader.GetNumChunks()); + }; - assert_get_num_chunks(11 /* data_size */, 1 /* chunk_size */, - 11 /* expected_num_chunks */); - assert_get_num_chunks(1 /* data_size */, 11 /* chunk_size */, - 1 /* expected_num_chunks */); - assert_get_num_chunks(0 /* data_size */, 11 /* chunk_size */, - 0 /* expected_num_chunks */); - assert_get_num_chunks(9 /* data_size */, 2 /* chunk_size */, - 5 /* expected_num_chunks */); - assert_get_num_chunks(10 /* data_size */, 2 /* chunk_size */, - 5 /* expected_num_chunks */); - 
assert_get_num_chunks(11 /* data_size */, 2 /* chunk_size */, - 6 /* expected_num_chunks */); + assert_get_num_chunks( + 11 /* data_size */, 1 /* chunk_size */, 11 /* expected_num_chunks */); + assert_get_num_chunks( + 1 /* data_size */, 11 /* chunk_size */, 1 /* expected_num_chunks */); + assert_get_num_chunks( + 0 /* data_size */, 11 /* chunk_size */, 0 /* expected_num_chunks */); + assert_get_num_chunks( + 9 /* data_size */, 2 /* chunk_size */, 5 /* expected_num_chunks */); + assert_get_num_chunks( + 10 /* data_size */, 2 /* chunk_size */, 5 /* expected_num_chunks */); + assert_get_num_chunks( + 11 /* data_size */, 2 /* chunk_size */, 6 /* expected_num_chunks */); } TEST(SpilledObjectReaderTest, CreateSpilledObjectReader) { - auto object_url = CreateSpilledObjectReaderOnTmp(10 /* object_offset */, "data", - "metadata", ray::rpc::Address()); + auto object_url = CreateSpilledObjectReaderOnTmp( + 10 /* object_offset */, "data", "metadata", ray::rpc::Address()); ASSERT_TRUE(SpilledObjectReader::CreateSpilledObjectReader(object_url).has_value()); ASSERT_FALSE( SpilledObjectReader::CreateSpilledObjectReader("malformatted_url").has_value()); auto optional_object = SpilledObjectReader::CreateSpilledObjectReader(object_url); ASSERT_TRUE(optional_object.has_value()); - auto object_url1 = - CreateSpilledObjectReaderOnTmp(10 /* object_offset */, "data", "metadata", - ray::rpc::Address(), true /* skip_write */); + auto object_url1 = CreateSpilledObjectReaderOnTmp(10 /* object_offset */, + "data", + "metadata", + ray::rpc::Address(), + true /* skip_write */); // file corrupted. ASSERT_FALSE(SpilledObjectReader::CreateSpilledObjectReader(object_url1).has_value()); } template -std::shared_ptr CreateObjectReader(std::string &data, std::string &metadata, +std::shared_ptr CreateObjectReader(std::string &data, + std::string &metadata, rpc::Address owner_address); template <> std::shared_ptr CreateObjectReader( std::string &data, std::string &metadata, rpc::Address owner_address) { - auto object_url = CreateSpilledObjectReaderOnTmp(0 /* object_offset */, data, metadata, - owner_address); + auto object_url = CreateSpilledObjectReaderOnTmp( + 0 /* object_offset */, data, metadata, owner_address); auto optional_object = SpilledObjectReader::CreateSpilledObjectReader(object_url); return std::make_shared(std::move(optional_object.value())); } @@ -289,7 +315,8 @@ std::shared_ptr CreateObjectReader( template struct ObjectReaderTest : public ::testing::Test { - static std::shared_ptr CreateObjectReader_(std::string &data, std::string &metadata, + static std::shared_ptr CreateObjectReader_(std::string &data, + std::string &metadata, rpc::Address owner_address) { return CreateObjectReader(data, metadata, owner_address); } diff --git a/src/ray/pubsub/mock_pubsub.h b/src/ray/pubsub/mock_pubsub.h index 2265883fd..83dec35f7 100644 --- a/src/ray/pubsub/mock_pubsub.h +++ b/src/ray/pubsub/mock_pubsub.h @@ -29,7 +29,8 @@ class MockSubscriber : public pubsub::SubscriberInterface { MOCK_METHOD7(Subscribe, bool(std::unique_ptr sub_message, const rpc::ChannelType channel_type, - const rpc::Address &owner_address, const std::string &key_id, + const rpc::Address &owner_address, + const std::string &key_id, pubsub::SubscribeDoneCallback subscribe_done_callback, pubsub::SubscriptionItemCallback subscription_callback, pubsub::SubscriptionFailureCallback subscription_failure_callback)); @@ -44,29 +45,34 @@ class MockSubscriber : public pubsub::SubscriberInterface { MOCK_METHOD3(Unsubscribe, bool(const rpc::ChannelType channel_type, - 
const rpc::Address &publisher_address, const std::string &key_id));
+                   const rpc::Address &publisher_address,
+                   const std::string &key_id));
-  MOCK_METHOD2(UnsubscribeChannel, bool(const rpc::ChannelType channel_type,
-                                        const rpc::Address &publisher_address));
+  MOCK_METHOD2(UnsubscribeChannel,
+               bool(const rpc::ChannelType channel_type,
+                    const rpc::Address &publisher_address));
-  MOCK_CONST_METHOD3(IsSubscribed, bool(const rpc::ChannelType channel_type,
-                                        const rpc::Address &publisher_address,
-                                        const std::string &key_id));
+  MOCK_CONST_METHOD3(IsSubscribed,
+                     bool(const rpc::ChannelType channel_type,
+                          const rpc::Address &publisher_address,
+                          const std::string &key_id));
   MOCK_CONST_METHOD0(DebugString, std::string());
 };

 class MockPublisher : public pubsub::PublisherInterface {
  public:
-  MOCK_METHOD3(RegisterSubscription, bool(const rpc::ChannelType channel_type,
-                                          const pubsub::SubscriberID &subscriber_id,
-                                          const std::optional<std::string> &key_id));
+  MOCK_METHOD3(RegisterSubscription,
+               bool(const rpc::ChannelType channel_type,
+                    const pubsub::SubscriberID &subscriber_id,
+                    const std::optional<std::string> &key_id));
   MOCK_METHOD1(Publish, void(const rpc::PubMessage &pub_message));
-  MOCK_METHOD3(UnregisterSubscription, bool(const rpc::ChannelType channel_type,
-                                            const pubsub::SubscriberID &subscriber_id,
-                                            const std::optional<std::string> &key_id));
+  MOCK_METHOD3(UnregisterSubscription,
+               bool(const rpc::ChannelType channel_type,
+                    const pubsub::SubscriberID &subscriber_id,
+                    const std::optional<std::string> &key_id));
   MOCK_METHOD2(PublishFailure,
                void(const rpc::ChannelType channel_type, const std::string &key_id));
diff --git a/src/ray/pubsub/publisher.cc b/src/ray/pubsub/publisher.cc
index 4440e08c8..6111389e3 100644
--- a/src/ray/pubsub/publisher.cc
+++ b/src/ray/pubsub/publisher.cc
@@ -39,8 +39,8 @@ std::vector<SubscriberID> SubscriptionIndex::GetSubscriberIdsByKeyId(
     const std::string &key_id) const {
   std::vector<SubscriberID> subscribers;
   if (!subscribers_to_all_.empty()) {
-    subscribers.insert(subscribers.end(), subscribers_to_all_.begin(),
-                       subscribers_to_all_.end());
+    subscribers.insert(
+        subscribers.end(), subscribers_to_all_.begin(), subscribers_to_all_.end());
   }
   auto it = key_id_to_subscribers_.find(key_id);
   if (it != key_id_to_subscribers_.end()) {
@@ -217,9 +217,12 @@ void Publisher::ConnectToSubscriber(const rpc::PubsubLongPollingRequest &request
   auto it = subscribers_.find(subscriber_id);
   if (it == subscribers_.end()) {
     it = subscribers_
-             .emplace(subscriber_id, std::make_shared<SubscriberState>(
-                                         subscriber_id, get_time_ms_,
-                                         subscriber_timeout_ms_, publish_batch_size_))
+             .emplace(
+                 subscriber_id,
+                 std::make_shared<SubscriberState>(subscriber_id,
+                                                   get_time_ms_,
+                                                   subscriber_timeout_ms_,
+                                                   publish_batch_size_))
              .first;
   }
   auto &subscriber = it->second;
@@ -233,9 +236,10 @@ bool Publisher::RegisterSubscription(const rpc::ChannelType channel_type,
                                      const std::optional<std::string> &key_id) {
   absl::MutexLock lock(&mutex_);
   if (!subscribers_.contains(subscriber_id)) {
-    subscribers_.emplace(subscriber_id, std::make_shared<SubscriberState>(
-                                            subscriber_id, get_time_ms_,
-                                            subscriber_timeout_ms_, publish_batch_size_));
+    subscribers_.emplace(
+        subscriber_id,
+        std::make_shared<SubscriberState>(
+            subscriber_id, get_time_ms_, subscriber_timeout_ms_, publish_batch_size_));
   }
   auto subscription_index_it = subscription_index_map_.find(channel_type);
   RAY_CHECK(subscription_index_it != subscription_index_map_.end());
diff --git a/src/ray/pubsub/publisher.h b/src/ray/pubsub/publisher.h
index a7d372608..f7300274e 100644
--- a/src/ray/pubsub/publisher.h
+++ b/src/ray/pubsub/publisher.h
@@ -98,8 +98,10 @@ struct LongPollConnection {

 /// Keeps the state of each connected subscriber.
 class SubscriberState {
  public:
-  SubscriberState(SubscriberID subscriber_id, std::function<double()> get_time_ms,
-                  uint64_t connection_timeout_ms, const int publish_batch_size)
+  SubscriberState(SubscriberID subscriber_id,
+                  std::function<double()> get_time_ms,
+                  uint64_t connection_timeout_ms,
+                  const int publish_batch_size)
       : subscriber_id_(subscriber_id),
         get_time_ms_(std::move(get_time_ms)),
         connection_timeout_ms_(connection_timeout_ms),
@@ -232,7 +234,8 @@ class Publisher : public PublisherInterface {
   /// \param publish_batch_size The batch size of published messages.
   Publisher(const std::vector<rpc::ChannelType> &channels,
             PeriodicalRunner *const periodical_runner,
-            std::function<double()> get_time_ms, const uint64_t subscriber_timeout_ms,
+            std::function<double()> get_time_ms,
+            const uint64_t subscriber_timeout_ms,
             const int publish_batch_size)
       : periodical_runner_(periodical_runner),
         get_time_ms_(std::move(get_time_ms)),
diff --git a/src/ray/pubsub/subscriber.cc b/src/ray/pubsub/subscriber.cc
index 9fcad0aef..f36fbabf5 100644
--- a/src/ray/pubsub/subscriber.cc
+++ b/src/ray/pubsub/subscriber.cc
@@ -23,7 +23,8 @@ namespace pubsub {
 ///////////////////////////////////////////////////////////////////////////////

 bool SubscriberChannel::Subscribe(
-    const rpc::Address &publisher_address, const std::optional<std::string> &key_id,
+    const rpc::Address &publisher_address,
+    const std::optional<std::string> &key_id,
     SubscriptionItemCallback subscription_callback,
     SubscriptionFailureCallback subscription_failure_callback) {
   cum_subscribe_requests_++;
@@ -32,8 +33,9 @@ bool SubscriberChannel::Subscribe(
   if (key_id) {
     return subscription_map_[publisher_id]
         .per_entity_subscription
-        .try_emplace(*key_id, SubscriptionInfo(std::move(subscription_callback),
-                                               std::move(subscription_failure_callback)))
+        .try_emplace(*key_id,
+                     SubscriptionInfo(std::move(subscription_callback),
+                                      std::move(subscription_failure_callback)))
         .second;
   }
   auto &all_entities_subscription =
@@ -187,15 +189,18 @@ void SubscriberChannel::HandlePublisherFailure(const rpc::Address &publisher_add
 }

 bool SubscriberChannel::HandlePublisherFailureInternal(
-    const rpc::Address &publisher_address, const std::string &key_id,
+    const rpc::Address &publisher_address,
+    const std::string &key_id,
     const Status &status) {
   auto maybe_failure_callback = GetFailureCallback(publisher_address, key_id);
   if (maybe_failure_callback.has_value()) {
     const auto &channel_name =
         rpc::ChannelType_descriptor()->FindValueByNumber(channel_type_)->name();
-    callback_service_->post([failure_callback = std::move(maybe_failure_callback.value()),
-                             key_id, status]() { failure_callback(key_id, status); },
-                            "Subscriber.HandleFailureCallback_" + channel_name);
+    callback_service_->post(
+        [failure_callback = std::move(maybe_failure_callback.value()), key_id, status]() {
+          failure_callback(key_id, status);
+        },
+        "Subscriber.HandleFailureCallback_" + channel_name);
     return true;
   }
   return false;
@@ -229,19 +234,27 @@ bool Subscriber::Subscribe(std::unique_ptr<rpc::SubMessage> sub_message,
                           SubscribeDoneCallback subscribe_done_callback,
                           SubscriptionItemCallback subscription_callback,
                           SubscriptionFailureCallback subscription_failure_callback) {
-  return SubscribeInternal(std::move(sub_message), channel_type, publisher_address,
-                           key_id, std::move(subscribe_done_callback),
+  return SubscribeInternal(std::move(sub_message),
+                           channel_type,
+                           publisher_address,
+                           key_id,
+                           std::move(subscribe_done_callback),
                            std::move(subscription_callback),
                            std::move(subscription_failure_callback));
 }

 bool Subscriber::SubscribeChannel(
-    std::unique_ptr<rpc::SubMessage> sub_message, const rpc::ChannelType channel_type,
-    const rpc::Address &publisher_address, SubscribeDoneCallback subscribe_done_callback,
+    std::unique_ptr<rpc::SubMessage> sub_message,
+    const rpc::ChannelType channel_type,
+    const rpc::Address &publisher_address,
+    SubscribeDoneCallback subscribe_done_callback,
     SubscriptionItemCallback subscription_callback,
     SubscriptionFailureCallback subscription_failure_callback) {
-  return SubscribeInternal(std::move(sub_message), channel_type, publisher_address,
-                           std::nullopt, std::move(subscribe_done_callback),
+  return SubscribeInternal(std::move(sub_message),
+                           channel_type,
+                           publisher_address,
+                           std::nullopt,
+                           std::move(subscribe_done_callback),
                            std::move(subscription_callback),
                            std::move(subscription_failure_callback));
 }
@@ -290,8 +303,10 @@ bool Subscriber::IsSubscribed(const rpc::ChannelType channel_type,
 }

 bool Subscriber::SubscribeInternal(
-    std::unique_ptr<rpc::SubMessage> sub_message, const rpc::ChannelType channel_type,
-    const rpc::Address &publisher_address, const std::optional<std::string> &key_id,
+    std::unique_ptr<rpc::SubMessage> sub_message,
+    const rpc::ChannelType channel_type,
+    const rpc::Address &publisher_address,
+    const std::optional<std::string> &key_id,
     SubscribeDoneCallback subscribe_done_callback,
     SubscriptionItemCallback subscription_callback,
     SubscriptionFailureCallback subscription_failure_callback) {
@@ -312,7 +327,9 @@ bool Subscriber::SubscribeInternal(
   SendCommandBatchIfPossible(publisher_address);
   MakeLongPollingConnectionIfNotConnected(publisher_address);
   return Channel(channel_type)
-      ->Subscribe(publisher_address, key_id, std::move(subscription_callback),
+      ->Subscribe(publisher_address,
+                  key_id,
+                  std::move(subscription_callback),
                   std::move(subscription_failure_callback));
 }
diff --git a/src/ray/pubsub/subscriber.h b/src/ray/pubsub/subscriber.h
index f495f0b2d..679e344f9 100644
--- a/src/ray/pubsub/subscriber.h
+++ b/src/ray/pubsub/subscriber.h
@@ -146,7 +146,8 @@ class SubscriberChannel {
   /// id. \return Return true if the given key id needs to be unsubscribed. False
   /// otherwise.
   bool HandlePublisherFailureInternal(const rpc::Address &publisher_address,
-                                      const std::string &key_id, const Status &status);
+                                      const std::string &key_id,
+                                      const Status &status);

   /// Returns a subscription callback; Returns a nullopt if the object id is not
   /// subscribed.
@@ -231,8 +232,10 @@ class SubscriberInterface {
   /// \param subscription_failure_callback A callback that is invoked whenever the
   /// connection to publisher is broken (e.g. the publisher fails).
   [[nodiscard]] virtual bool Subscribe(
-      std::unique_ptr<rpc::SubMessage> sub_message, rpc::ChannelType channel_type,
-      const rpc::Address &publisher_address, const std::string &key_id,
+      std::unique_ptr<rpc::SubMessage> sub_message,
+      rpc::ChannelType channel_type,
+      const rpc::Address &publisher_address,
+      const std::string &key_id,
       SubscribeDoneCallback subscribe_done_callback,
       SubscriptionItemCallback subscription_callback,
       SubscriptionFailureCallback subscription_failure_callback) = 0;
@@ -247,7 +250,8 @@ class SubscriberInterface {
   /// \param subscription_failure_callback A callback that is invoked whenever the
   /// connection to publisher is broken (e.g. the publisher fails).
   [[nodiscard]] virtual bool SubscribeChannel(
-      std::unique_ptr<rpc::SubMessage> sub_message, rpc::ChannelType channel_type,
+      std::unique_ptr<rpc::SubMessage> sub_message,
+      rpc::ChannelType channel_type,
       const rpc::Address &publisher_address,
       SubscribeDoneCallback subscribe_done_callback,
       SubscriptionItemCallback subscription_callback,
@@ -322,7 +326,8 @@ class SubscriberClientInterface {
 class Subscriber : public SubscriberInterface {
  public:
   Subscriber(
-      const SubscriberID subscriber_id, const std::vector<rpc::ChannelType> &channels,
+      const SubscriberID subscriber_id,
+      const std::vector<rpc::ChannelType> &channels,
       const int64_t max_command_batch_size,
       std::function<std::shared_ptr<SubscriberClientInterface>(const rpc::Address &)>
          get_client,
@@ -340,13 +345,15 @@ class Subscriber : public SubscriberInterface {
   bool Subscribe(std::unique_ptr<rpc::SubMessage> sub_message,
                  const rpc::ChannelType channel_type,
-                 const rpc::Address &publisher_address, const std::string &key_id,
+                 const rpc::Address &publisher_address,
+                 const std::string &key_id,
                  SubscribeDoneCallback subscribe_done_callback,
                  SubscriptionItemCallback subscription_callback,
                  SubscriptionFailureCallback subscription_failure_callback) override;

   bool SubscribeChannel(
-      std::unique_ptr<rpc::SubMessage> sub_message, rpc::ChannelType channel_type,
+      std::unique_ptr<rpc::SubMessage> sub_message,
+      rpc::ChannelType channel_type,
       const rpc::Address &publisher_address,
       SubscribeDoneCallback subscribe_done_callback,
       SubscriptionItemCallback subscription_callback,
diff --git a/src/ray/pubsub/test/integration_test.cc b/src/ray/pubsub/test/integration_test.cc
index 7e95bf330..9f52c6e7a 100644
--- a/src/ray/pubsub/test/integration_test.cc
+++ b/src/ray/pubsub/test/integration_test.cc
@@ -39,36 +39,41 @@ class SubscriberServiceImpl final : public rpc::SubscriberService::CallbackServi
       : publisher_(std::move(publisher)) {}

   grpc::ServerUnaryReactor *PubsubLongPolling(
-      grpc::CallbackServerContext *context, const rpc::PubsubLongPollingRequest *request,
+      grpc::CallbackServerContext *context,
+      const rpc::PubsubLongPollingRequest *request,
       rpc::PubsubLongPollingReply *reply) override {
     auto *reactor = context->DefaultReactor();
-    publisher_->ConnectToSubscriber(
-        *request, reply,
-        [reactor](ray::Status status, std::function<void()> success_cb,
-                  std::function<void()> failure_cb) {
-          // Long polling should always succeed.
-          RAY_CHECK_OK(status);
-          RAY_CHECK(success_cb == nullptr);
-          RAY_CHECK(failure_cb == nullptr);
-          reactor->Finish(grpc::Status::OK);
-        });
+    publisher_->ConnectToSubscriber(*request,
+                                    reply,
+                                    [reactor](ray::Status status,
+                                              std::function<void()> success_cb,
+                                              std::function<void()> failure_cb) {
+                                      // Long polling should always succeed.
+                                      RAY_CHECK_OK(status);
+                                      RAY_CHECK(success_cb == nullptr);
+                                      RAY_CHECK(failure_cb == nullptr);
+                                      reactor->Finish(grpc::Status::OK);
+                                    });
     return reactor;
   }

   // For simplicity, all work is done on the GRPC thread.
   grpc::ServerUnaryReactor *PubsubCommandBatch(
-      grpc::CallbackServerContext *context, const rpc::PubsubCommandBatchRequest *request,
+      grpc::CallbackServerContext *context,
+      const rpc::PubsubCommandBatchRequest *request,
       rpc::PubsubCommandBatchReply *reply) override {
     const auto subscriber_id = UniqueID::FromBinary(request->subscriber_id());
     auto *reactor = context->DefaultReactor();
     for (const auto &command : request->commands()) {
       if (command.has_unsubscribe_message()) {
-        publisher_->UnregisterSubscription(command.channel_type(), subscriber_id,
+        publisher_->UnregisterSubscription(command.channel_type(),
+                                           subscriber_id,
                                            command.key_id().empty()
                                                ? std::nullopt
                                                : std::make_optional(command.key_id()));
       } else if (command.has_subscribe_message()) {
-        publisher_->RegisterSubscription(command.channel_type(), subscriber_id,
+        publisher_->RegisterSubscription(command.channel_type(),
+                                         subscriber_id,
                                          command.key_id().empty()
                                              ? std::nullopt
                                              : std::make_optional(command.key_id()));
@@ -104,12 +109,12 @@ class CallbackSubscriberClient final : public pubsub::SubscriberClientInterface
       const rpc::ClientCallback<rpc::PubsubLongPollingReply> &callback) final {
     auto *context = new grpc::ClientContext;
     auto *reply = new rpc::PubsubLongPollingReply;
-    stub_->async()->PubsubLongPolling(context, &request, reply,
-                                      [callback, context, reply](grpc::Status s) {
-                                        callback(GrpcStatusToRayStatus(s), *reply);
-                                        delete reply;
-                                        delete context;
-                                      });
+    stub_->async()->PubsubLongPolling(
+        context, &request, reply, [callback, context, reply](grpc::Status s) {
+          callback(GrpcStatusToRayStatus(s), *reply);
+          delete reply;
+          delete context;
+        });
   }

   void PubsubCommandBatch(
@@ -117,12 +122,12 @@ class CallbackSubscriberClient final : public pubsub::SubscriberClientInterface
       const rpc::ClientCallback<rpc::PubsubCommandBatchReply> &callback) final {
     auto *context = new grpc::ClientContext;
     auto *reply = new rpc::PubsubCommandBatchReply;
-    stub_->async()->PubsubCommandBatch(context, &request, reply,
-                                       [callback, context, reply](grpc::Status s) {
-                                         callback(GrpcStatusToRayStatus(s), *reply);
-                                         delete reply;
-                                         delete context;
-                                       });
+    stub_->async()->PubsubCommandBatch(
+        context, &request, reply, [callback, context, reply](grpc::Status s) {
+          callback(GrpcStatusToRayStatus(s), *reply);
+          delete reply;
+          delete context;
+        });
   }

  private:
@@ -207,8 +212,10 @@ TEST_F(IntegrationTest, SubscribersToOneIDAndAllIDs) {
   std::vector<rpc::ActorTableData> actors_1;
   auto subscriber_1 = CreateSubscriber();
   subscriber_1->Subscribe(
-      std::make_unique<rpc::SubMessage>(), rpc::ChannelType::GCS_ACTOR_CHANNEL,
-      address_proto_, subscribed_actor,
+      std::make_unique<rpc::SubMessage>(),
+      rpc::ChannelType::GCS_ACTOR_CHANNEL,
+      address_proto_,
+      subscribed_actor,
       /*subscribe_done_callback=*/
       [&counter](Status status) {
         RAY_CHECK_OK(status);
@@ -225,7 +232,8 @@ TEST_F(IntegrationTest, SubscribersToOneIDAndAllIDs) {
   std::vector<rpc::ActorTableData> actors_2;
   auto subscriber_2 = CreateSubscriber();
   subscriber_2->SubscribeChannel(
-      std::make_unique<rpc::SubMessage>(), rpc::ChannelType::GCS_ACTOR_CHANNEL,
+      std::make_unique<rpc::SubMessage>(),
+      rpc::ChannelType::GCS_ACTOR_CHANNEL,
       address_proto_,
       /*subscribe_done_callback=*/
       [&counter](Status status) {
@@ -275,8 +283,8 @@ TEST_F(IntegrationTest, SubscribersToOneIDAndAllIDs) {
   EXPECT_EQ(actors_1[0].actor_id(), actor_data.actor_id());
   EXPECT_EQ(actors_2[0].actor_id(), actor_data.actor_id());

-  subscriber_1->Unsubscribe(rpc::ChannelType::GCS_ACTOR_CHANNEL, address_proto_,
-                            subscribed_actor);
+  subscriber_1->Unsubscribe(
+      rpc::ChannelType::GCS_ACTOR_CHANNEL, address_proto_, subscribed_actor);
   subscriber_2->UnsubscribeChannel(rpc::ChannelType::GCS_ACTOR_CHANNEL, address_proto_);

   // Flush all the inflight long polling.
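Every hunk in this patch applies the same mechanical rule: when a parameter or argument list no longer fits on one line, each parameter/argument gets its own line instead of being bin-packed two or three per line. A minimal, self-contained sketch of the rule (illustration only: the function and argument names are invented, and it assumes the behavior comes from disabling clang-format's BinPackParameters/BinPackArguments options, since the .clang-format hunk itself is outside this excerpt):

    #include <string>
    #include <vector>

    // Hypothetical function, used only to illustrate the wrapping rule.
    void RegisterSubscription(const std::string &channel_type,
                              const std::string &subscriber_id,
                              const std::string &key_id,
                              const std::vector<std::string> &options,
                              bool enable_batching) {}

    int main() {
      // Short calls that fit within the column limit are left on one line.
      // Once the limit is exceeded, every argument gets its own line, so a
      // change to one argument no longer rewraps its neighbors.
      RegisterSubscription("WORKER_OBJECT_EVICTION",
                           "subscriber-1",
                           "object-key",
                           /*options=*/{},
                           /*enable_batching=*/true);
      return 0;
    }
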
diff --git a/src/ray/pubsub/test/publisher_test.cc b/src/ray/pubsub/test/publisher_test.cc index 10895632c..caaf70c52 100644 --- a/src/ray/pubsub/test/publisher_test.cc +++ b/src/ray/pubsub/test/publisher_test.cc @@ -290,8 +290,8 @@ TEST_F(PublisherTest, TestSubscriber) { absl::flat_hash_set object_ids_published; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &object_ids_published](Status status, std::function success, - std::function failure) { + [&reply, &object_ids_published]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -344,8 +344,8 @@ TEST_F(PublisherTest, TestSubscriberBatchSize) { absl::flat_hash_set object_ids_published; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &object_ids_published](Status status, std::function success, - std::function failure) { + [&reply, &object_ids_published]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -357,7 +357,9 @@ TEST_F(PublisherTest, TestSubscriberBatchSize) { auto max_publish_size = 5; auto subscriber = std::make_shared( - subscriber_id_, [this]() { return current_time_; }, subscriber_timeout_ms_, + subscriber_id_, + [this]() { return current_time_; }, + subscriber_timeout_ms_, max_publish_size); subscriber->ConnectToSubscriber(request_, &reply, send_reply_callback); @@ -396,7 +398,8 @@ TEST_F(PublisherTest, TestSubscriberActiveTimeout) { auto reply_cnt = 0; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply_cnt](Status status, std::function success, + [&reply_cnt](Status status, + std::function success, std::function failure) { reply_cnt++; }; auto subscriber = std::make_shared( @@ -456,7 +459,8 @@ TEST_F(PublisherTest, TestSubscriberDisconnected) { auto reply_cnt = 0; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply_cnt](Status status, std::function success, + [&reply_cnt](Status status, + std::function success, std::function failure) { reply_cnt++; }; auto subscriber = std::make_shared( @@ -515,7 +519,8 @@ TEST_F(PublisherTest, TestSubscriberTimeoutComplicated) { auto reply_cnt = 0; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply_cnt](Status status, std::function success, + [&reply_cnt](Status status, + std::function success, std::function failure) { reply_cnt++; }; auto subscriber = std::make_shared( @@ -557,8 +562,8 @@ TEST_F(PublisherTest, TestBasicSingleSubscriber) { std::vector batched_ids; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &batched_ids](Status status, std::function success, - std::function failure) { + [&reply, &batched_ids]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -571,8 +576,8 @@ TEST_F(PublisherTest, TestBasicSingleSubscriber) { const auto oid = ObjectID::FromRandom(); publisher_->ConnectToSubscriber(request_, &reply, send_reply_callback); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); 
publisher_->Publish(GeneratePubMessage(oid)); ASSERT_EQ(batched_ids[0], oid); } @@ -581,8 +586,8 @@ TEST_F(PublisherTest, TestNoConnectionWhenRegistered) { std::vector batched_ids; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &batched_ids](Status status, std::function success, - std::function failure) { + [&reply, &batched_ids]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -594,8 +599,8 @@ TEST_F(PublisherTest, TestNoConnectionWhenRegistered) { const auto oid = ObjectID::FromRandom(); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); publisher_->Publish(GeneratePubMessage(oid)); // Nothing has been published because there's no connection. ASSERT_EQ(batched_ids.size(), 0); @@ -608,8 +613,8 @@ TEST_F(PublisherTest, TestMultiObjectsFromSingleNode) { std::vector batched_ids; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &batched_ids](Status status, std::function success, - std::function failure) { + [&reply, &batched_ids]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -624,8 +629,8 @@ TEST_F(PublisherTest, TestMultiObjectsFromSingleNode) { for (int i = 0; i < num_oids; i++) { const auto oid = ObjectID::FromRandom(); oids.push_back(oid); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); publisher_->Publish(GeneratePubMessage(oid)); } ASSERT_EQ(batched_ids.size(), 0); @@ -643,8 +648,8 @@ TEST_F(PublisherTest, TestMultiObjectsFromMultiNodes) { std::vector batched_ids; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &batched_ids](Status status, std::function success, - std::function failure) { + [&reply, &batched_ids]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -665,8 +670,8 @@ TEST_F(PublisherTest, TestMultiObjectsFromMultiNodes) { // There will be one object per node. for (int i = 0; i < num_nodes; i++) { const auto oid = oids[i]; - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); publisher_->Publish(GeneratePubMessage(oid)); } ASSERT_EQ(batched_ids.size(), 0); @@ -685,8 +690,8 @@ TEST_F(PublisherTest, TestMultiSubscribers) { rpc::PubsubLongPollingReply reply; int reply_invoked = 0; rpc::SendReplyCallback send_reply_callback = - [&reply, &batched_ids, &reply_invoked](Status status, std::function success, - std::function failure) { + [&reply, &batched_ids, &reply_invoked]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); const auto oid = @@ -706,8 +711,8 @@ TEST_F(PublisherTest, TestMultiSubscribers) { // There will be one object per node. 
 for (int i = 0; i < num_nodes; i++) {
-    publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION,
-                                     subscriber_id_, oid.Binary());
+    publisher_->RegisterSubscription(
+        rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary());
   }

   ASSERT_EQ(batched_ids.size(), 0);
@@ -725,8 +730,8 @@ TEST_F(PublisherTest, TestBatch) {
   std::vector<ObjectID> batched_ids;
   rpc::PubsubLongPollingReply reply;
   rpc::SendReplyCallback send_reply_callback =
-      [&reply, &batched_ids](Status status, std::function<void()> success,
-                             std::function<void()> failure) {
+      [&reply, &batched_ids](
+          Status status, std::function<void()> success, std::function<void()> failure) {
        for (int i = 0; i < reply.pub_messages_size(); i++) {
          const auto &msg = reply.pub_messages(i);
          const auto oid =
@@ -741,8 +746,8 @@
   for (int i = 0; i < num_oids; i++) {
     const auto oid = ObjectID::FromRandom();
     oids.push_back(oid);
-    publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION,
-                                     subscriber_id_, oid.Binary());
+    publisher_->RegisterSubscription(
+        rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary());
     publisher_->Publish(GeneratePubMessage(oid));
   }
   ASSERT_EQ(batched_ids.size(), 0);
@@ -761,8 +766,8 @@ TEST_F(PublisherTest, TestBatch) {
   for (int i = 0; i < num_oids; i++) {
     const auto oid = ObjectID::FromRandom();
     oids.push_back(oid);
-    publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION,
-                                     subscriber_id_, oid.Binary());
+    publisher_->RegisterSubscription(
+        rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary());
     publisher_->Publish(GeneratePubMessage(oid));
   }
   publisher_->ConnectToSubscriber(request_, &reply, send_reply_callback);
@@ -777,16 +782,16 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionExisted) {
   bool long_polling_connection_replied = false;
   rpc::PubsubLongPollingReply reply;
   rpc::SendReplyCallback send_reply_callback =
-      [&long_polling_connection_replied](Status status, std::function<void()> success,
-                                         std::function<void()> failure) {
+      [&long_polling_connection_replied](
+          Status status, std::function<void()> success, std::function<void()> failure) {
        long_polling_connection_replied = true;
      };
   const auto oid = ObjectID::FromRandom();

   publisher_->ConnectToSubscriber(request_, &reply, send_reply_callback);
   // This information should be cleaned up as the subscriber is dead.
-  publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION,
-                                   subscriber_id_, oid.Binary());
+  publisher_->RegisterSubscription(
+      rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary());
   // Timeout is reached. The connection should've been refreshed. Since the subscriber is
   // dead, no new connection is made.
   current_time_ += subscriber_timeout_ms_;
@@ -806,8 +811,8 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionExisted) {
   // New subscriber is registered for some reason. Since there's no new long polling
   // connection for the timeout, it should be removed.
long_polling_connection_replied = false; - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); current_time_ += subscriber_timeout_ms_; publisher_->CheckDeadSubscribers(); erased = publisher_->UnregisterSubscriber(subscriber_id_); @@ -819,8 +824,8 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) { bool long_polling_connection_replied = false; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&long_polling_connection_replied](Status status, std::function success, - std::function failure) { + [&long_polling_connection_replied]( + Status status, std::function success, std::function failure) { long_polling_connection_replied = true; }; @@ -828,8 +833,8 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) { /// Test the case where there was a registration, but no connection. /// auto oid = ObjectID::FromRandom(); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); publisher_->Publish(GeneratePubMessage(oid)); // There was no long polling connection yet. ASSERT_EQ(long_polling_connection_replied, false); @@ -849,8 +854,8 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) { /// Test the case where there's no connection coming at all when there was a /// registration. - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); publisher_->Publish(GeneratePubMessage(oid)); // No new long polling connection was made until timeout. @@ -865,15 +870,15 @@ TEST_F(PublisherTest, TestUnregisterSubscription) { bool long_polling_connection_replied = false; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&long_polling_connection_replied](Status status, std::function success, - std::function failure) { + [&long_polling_connection_replied]( + Status status, std::function success, std::function failure) { long_polling_connection_replied = true; }; const auto oid = ObjectID::FromRandom(); publisher_->ConnectToSubscriber(request_, &reply, send_reply_callback); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); ASSERT_EQ(long_polling_connection_replied, false); // Connection should be replied (removed) when the subscriber is unregistered. @@ -883,13 +888,14 @@ TEST_F(PublisherTest, TestUnregisterSubscription) { ASSERT_EQ(long_polling_connection_replied, false); // Make sure when the entries don't exist, it doesn't delete anything. 
- ASSERT_EQ( - publisher_->UnregisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, ObjectID::FromRandom().Binary()), - 0); ASSERT_EQ(publisher_->UnregisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - NodeID::FromRandom(), oid.Binary()), + subscriber_id_, + ObjectID::FromRandom().Binary()), 0); + ASSERT_EQ( + publisher_->UnregisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, NodeID::FromRandom(), oid.Binary()), + 0); ASSERT_EQ(publisher_->UnregisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, NodeID::FromRandom(), ObjectID::FromRandom().Binary()), @@ -905,16 +911,16 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) { bool long_polling_connection_replied = false; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&long_polling_connection_replied](Status status, std::function success, - std::function failure) { + [&long_polling_connection_replied]( + Status status, std::function success, std::function failure) { long_polling_connection_replied = true; }; // Test basic. const auto oid = ObjectID::FromRandom(); publisher_->ConnectToSubscriber(request_, &reply, send_reply_callback); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); ASSERT_EQ(long_polling_connection_replied, false); int erased = publisher_->UnregisterSubscriber(subscriber_id_); ASSERT_TRUE(erased); @@ -930,8 +936,8 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) { // Test when connect wasn't done. long_polling_connection_replied = false; - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); erased = publisher_->UnregisterSubscriber(subscriber_id_); ASSERT_TRUE(erased); ASSERT_EQ(long_polling_connection_replied, false); @@ -941,25 +947,25 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) { // Test if registration / unregistration is idempotent. 
TEST_F(PublisherTest, TestRegistrationIdempotency) { const auto oid = ObjectID::FromRandom(); - ASSERT_TRUE(publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); - ASSERT_FALSE(publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); - ASSERT_FALSE(publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); - ASSERT_FALSE(publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); + ASSERT_TRUE(publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); + ASSERT_FALSE(publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); + ASSERT_FALSE(publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); + ASSERT_FALSE(publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); ASSERT_FALSE(publisher_->CheckNoLeaks()); - ASSERT_TRUE(publisher_->UnregisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); + ASSERT_TRUE(publisher_->UnregisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); ASSERT_FALSE(publisher_->UnregisterSubscription( rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); ASSERT_TRUE(publisher_->CheckNoLeaks()); - ASSERT_TRUE(publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); + ASSERT_TRUE(publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); ASSERT_FALSE(publisher_->CheckNoLeaks()); - ASSERT_TRUE(publisher_->UnregisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary())); + ASSERT_TRUE(publisher_->UnregisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary())); } TEST_F(PublisherTest, TestPublishFailure) { @@ -969,8 +975,8 @@ TEST_F(PublisherTest, TestPublishFailure) { std::vector failed_ids; rpc::PubsubLongPollingReply reply; rpc::SendReplyCallback send_reply_callback = - [&reply, &failed_ids](Status status, std::function success, - std::function failure) { + [&reply, &failed_ids]( + Status status, std::function success, std::function failure) { for (int i = 0; i < reply.pub_messages_size(); i++) { const auto &msg = reply.pub_messages(i); RAY_LOG(ERROR) << "ha"; @@ -985,8 +991,8 @@ TEST_F(PublisherTest, TestPublishFailure) { const auto oid = ObjectID::FromRandom(); publisher_->ConnectToSubscriber(request_, &reply, send_reply_callback); - publisher_->RegisterSubscription(rpc::ChannelType::WORKER_OBJECT_EVICTION, - subscriber_id_, oid.Binary()); + publisher_->RegisterSubscription( + rpc::ChannelType::WORKER_OBJECT_EVICTION, subscriber_id_, oid.Binary()); publisher_->PublishFailure(rpc::ChannelType::WORKER_OBJECT_EVICTION, oid.Binary()); ASSERT_EQ(failed_ids[0], oid); } diff --git a/src/ray/pubsub/test/subscriber_test.cc b/src/ray/pubsub/test/subscriber_test.cc index d11d3afaf..0dcddf18f 100644 --- a/src/ray/pubsub/test/subscriber_test.cc +++ b/src/ray/pubsub/test/subscriber_test.cc @@ -52,7 +52,8 @@ class MockWorkerClient : public pubsub::SubscriberClientInterface { return r; } - bool ReplyLongPolling(rpc::ChannelType channel_type, std::vector &object_ids, + bool ReplyLongPolling(rpc::ChannelType 
channel_type, + std::vector &object_ids, Status status = Status::OK()) { if (long_polling_callbacks.empty()) { return false; @@ -122,13 +123,16 @@ class SubscriberTest : public ::testing::Test { std::vector{rpc::ChannelType::WORKER_OBJECT_EVICTION, rpc::ChannelType::WORKER_REF_REMOVED_CHANNEL, rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL}, - /*max_command_batch_size*/ 3, client_pool, &callback_service_); + /*max_command_batch_size*/ 3, + client_pool, + &callback_service_); } const rpc::Address GenerateOwnerAddress( const std::string node_id = NodeID::FromRandom().Binary(), const std::string worker_id = WorkerID::FromRandom().Binary(), - const std::string address = "abc", const int port = 1234) { + const std::string address = "abc", + const int port = 1234) { rpc::Address addr; addr.set_raylet_id(node_id); addr.set_ip_address(address); @@ -144,7 +148,8 @@ class SubscriberTest : public ::testing::Test { return sub_message; } - bool ReplyLongPolling(rpc::ChannelType channel_type, std::vector &object_ids, + bool ReplyLongPolling(rpc::ChannelType channel_type, + std::vector &object_ids, Status status = Status::OK()) { auto success = owner_client->ReplyLongPolling(channel_type, object_ids, status); // Need to call this to invoke callback when the reply comes. @@ -187,9 +192,13 @@ TEST_F(SubscriberTest, TestBasicSubscription) { const auto object_id = ObjectID::FromRandom(); ASSERT_FALSE(subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary())); ASSERT_TRUE(owner_client->ReplyCommandBatch()); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); ASSERT_TRUE(subscriber_->IsSubscribed(channel, owner_addr, object_id.Binary())); @@ -226,9 +235,13 @@ TEST_F(SubscriberTest, TestSingleLongPollingWithMultipleSubscriptions) { for (int i = 0; i < 5; i++) { const auto object_id = ObjectID::FromRandom(); object_ids.push_back(object_id); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); ASSERT_TRUE(subscriber_->IsSubscribed(channel, owner_addr, object_id.Binary())); objects_batched.push_back(object_id); @@ -257,9 +270,13 @@ TEST_F(SubscriberTest, TestMultiLongPollingWithTheSameSubscription) { const auto owner_addr = GenerateOwnerAddress(); const auto object_id = ObjectID::FromRandom(); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); ASSERT_EQ(owner_client->GetNumberOfInFlightLongPollingRequests(), 1); ASSERT_TRUE(subscriber_->IsSubscribed(channel, owner_addr, object_id.Binary())); @@ -292,9 +309,13 @@ 
TEST_F(SubscriberTest, TestCallbackNotInvokedForNonSubscribedObject) { const auto owner_addr = GenerateOwnerAddress(); const auto object_id = ObjectID::FromRandom(); const auto object_id_not_subscribed = ObjectID::FromRandom(); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); // The object information is published. @@ -315,9 +336,12 @@ TEST_F(SubscriberTest, TestSubscribeChannelEntities) { auto failure_callback = EMPTY_FAILURE_CALLBACK; const auto owner_addr = GenerateOwnerAddress(); - subscriber_->SubscribeChannel(std::make_unique(), channel, owner_addr, + subscriber_->SubscribeChannel(std::make_unique(), + channel, + owner_addr, /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); ASSERT_EQ(owner_client->GetNumberOfInFlightLongPollingRequests(), 1); @@ -361,9 +385,13 @@ TEST_F(SubscriberTest, TestIgnoreBatchAfterUnsubscription) { const auto owner_addr = GenerateOwnerAddress(); const auto object_id = ObjectID::FromRandom(); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); ASSERT_TRUE(subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary())); ASSERT_TRUE(owner_client->ReplyCommandBatch()); @@ -390,9 +418,12 @@ TEST_F(SubscriberTest, TestIgnoreBatchAfterUnsubscribeFromAll) { auto failure_callback = EMPTY_FAILURE_CALLBACK; const auto owner_addr = GenerateOwnerAddress(); - subscriber_->SubscribeChannel(std::make_unique(), channel, owner_addr, + subscriber_->SubscribeChannel(std::make_unique(), + channel, + owner_addr, /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); ASSERT_TRUE(subscriber_->UnsubscribeChannel(channel, owner_addr)); ASSERT_TRUE(owner_client->ReplyCommandBatch()); @@ -420,9 +451,13 @@ TEST_F(SubscriberTest, TestLongPollingFailure) { auto failure_callback = [this, object_id](const std::string &key_id, const Status &) { object_failed_to_subscribe_.emplace(object_id); }; - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); // Long polling failed. 
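The reply-callback hunks in publisher_test.cc above also show the companion rule for lambdas: when a lambda's parameters cannot share a line with its capture list, clang-format now breaks right after the opening parenthesis and keeps the parameter list together on the next line. A compilable toy version (names and types invented for illustration; these are not Ray's actual callback types):

    #include <functional>
    #include <iostream>

    using ReplyCallback = std::function<void(
        int status, std::function<void()> success, std::function<void()> failure)>;

    int main() {
      int replies = 0;
      // The capture list stays on the first line; the parameter list moves as
      // a unit to the next line, mirroring the reformatted test callbacks.
      ReplyCallback send_reply_callback =
          [&replies](
              int status, std::function<void()> success, std::function<void()> failure) {
            replies++;
          };
      send_reply_callback(/*status=*/0, nullptr, nullptr);
      std::cout << replies << std::endl;  // Prints 1.
      return 0;
    }
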
@@ -452,9 +487,13 @@ TEST_F(SubscriberTest, TestUnsubscribeInSubscriptionCallback) { ASSERT_TRUE(false); }; - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); std::vector objects_batched; @@ -480,9 +519,13 @@ TEST_F(SubscriberTest, TestSubUnsubCommandBatchSingleEntry) { const auto owner_addr = GenerateOwnerAddress(); const auto object_id = ObjectID::FromRandom(); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); auto r = owner_client->ReplyCommandBatch(); auto commands = r->commands(); @@ -517,18 +560,30 @@ TEST_F(SubscriberTest, TestSubUnsubCommandBatchMultiEntries) { const auto object_id = ObjectID::FromRandom(); const auto object_id_2 = ObjectID::FromRandom(); // The first batch is always processed right away. - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); // Test multiple entries in the batch before new reply is coming. subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary()); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); - subscriber_->Subscribe(GenerateSubMessage(object_id_2), channel, owner_addr, - object_id_2.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id_2), + channel, + owner_addr, + object_id_2.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); // The long polling request is replied. New batch will be sent. std::vector objects_batched; @@ -583,14 +638,22 @@ TEST_F(SubscriberTest, TestSubUnsubCommandBatchMultiBatch) { // The first 3 will be in the first batch. subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary()); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary()); // Note that this request will be batched in the second batch. 
-  subscriber_->Subscribe(GenerateSubMessage(object_id_2), channel, owner_addr,
-                         object_id_2.Binary(), /*subscribe_done_callback=*/nullptr,
-                         subscription_callback, failure_callback);
+  subscriber_->Subscribe(GenerateSubMessage(object_id_2),
+                         channel,
+                         owner_addr,
+                         object_id_2.Binary(),
+                         /*subscribe_done_callback=*/nullptr,
+                         subscription_callback,
+                         failure_callback);

   // The long polling request is replied.
   std::vector<ObjectID> objects_batched;
@@ -632,16 +695,24 @@ TEST_F(SubscriberTest, TestOnlyOneInFlightCommandBatch) {
   const auto object_id = ObjectID::FromRandom();
   // The first batch is sent right away. There should be no more in flight request until
   // it is replied.
-  subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr,
-                         object_id.Binary(), /*subscribe_done_callback=*/nullptr,
-                         subscription_callback, failure_callback);
+  subscriber_->Subscribe(GenerateSubMessage(object_id),
+                         channel,
+                         owner_addr,
+                         object_id.Binary(),
+                         /*subscribe_done_callback=*/nullptr,
+                         subscription_callback,
+                         failure_callback);

   // These two subscribe requests are sent in the next batch.
   for (int i = 0; i < 2; i++) {
     const auto object_id = ObjectID::FromRandom();
-    subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr,
-                           object_id.Binary(), /*subscribe_done_callback=*/nullptr,
-                           subscription_callback, failure_callback);
+    subscriber_->Subscribe(GenerateSubMessage(object_id),
+                           channel,
+                           owner_addr,
+                           object_id.Binary(),
+                           /*subscribe_done_callback=*/nullptr,
+                           subscription_callback,
+                           failure_callback);
   }

   // The first batch is replied. The second batch should be sent.
@@ -668,16 +739,24 @@ TEST_F(SubscriberTest, TestCommandsCleanedUponPublishFailure) {
   const auto owner_addr = GenerateOwnerAddress();
   const auto object_id = ObjectID::FromRandom();
-  subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr,
-                         object_id.Binary(), /*subscribe_done_callback=*/nullptr,
-                         subscription_callback, failure_callback);
+  subscriber_->Subscribe(GenerateSubMessage(object_id),
+                         channel,
+                         owner_addr,
+                         object_id.Binary(),
+                         /*subscribe_done_callback=*/nullptr,
+                         subscription_callback,
+                         failure_callback);

   // These two subscribe requests are sent to the next batch.
for (int i = 0; i < 2; i++) { const auto object_id = ObjectID::FromRandom(); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); } std::vector objects_batched; @@ -709,12 +788,20 @@ TEST_F(SubscriberTest, TestFailureMessagePublished) { const auto id = ObjectID::FromBinary(key_id); object_failed_to_subscribe_.emplace(id); }; - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id2.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id2.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(owner_client->ReplyCommandBatch()); // Failure message is published. @@ -748,9 +835,13 @@ TEST_F(SubscriberTest, TestIsSubscribed) { ASSERT_FALSE(subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary())); ASSERT_FALSE(subscriber_->IsSubscribed(channel, owner_addr, object_id.Binary())); - subscriber_->Subscribe(GenerateSubMessage(object_id), channel, owner_addr, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, - subscription_callback, failure_callback); + subscriber_->Subscribe(GenerateSubMessage(object_id), + channel, + owner_addr, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + failure_callback); ASSERT_TRUE(subscriber_->IsSubscribed(channel, owner_addr, object_id.Binary())); ASSERT_TRUE(subscriber_->Unsubscribe(channel, owner_addr, object_id.Binary())); diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc index 5c023d0c0..140aeba85 100644 --- a/src/ray/raylet/agent_manager.cc +++ b/src/ray/raylet/agent_manager.cc @@ -136,7 +136,8 @@ void AgentManager::StartAgent() { } void AgentManager::CreateRuntimeEnv( - const JobID &job_id, const std::string &serialized_runtime_env, + const JobID &job_id, + const std::string &serialized_runtime_env, const std::string &serialized_allocated_resource_instances, CreateRuntimeEnvCallback callback) { // If the agent cannot be started, fail the request. @@ -153,7 +154,8 @@ void AgentManager::CreateRuntimeEnv( // and causing a segfault. 
delay_executor_( [callback = std::move(callback), error_message] { - callback(/*successful=*/false, /*serialized_runtime_env_context=*/"", + callback(/*successful=*/false, + /*serialized_runtime_env_context=*/"", /*setup_error_message*/ error_message); }, 0); @@ -170,7 +172,8 @@ void AgentManager::CreateRuntimeEnv( RAY_LOG(WARNING) << error_message; delay_executor_( [callback = std::move(callback), - serialized_runtime_env = std::move(serialized_runtime_env), error_message] { + serialized_runtime_env = std::move(serialized_runtime_env), + error_message] { callback(/*successful=*/false, /*serialized_runtime_env_context=*/serialized_runtime_env, /*setup_error_message*/ error_message); @@ -183,10 +186,15 @@ void AgentManager::CreateRuntimeEnv( << "Runtime env agent is not registered yet. Will retry CreateRuntimeEnv later: " << serialized_runtime_env; delay_executor_( - [this, job_id, serialized_runtime_env, serialized_allocated_resource_instances, + [this, + job_id, + serialized_runtime_env, + serialized_allocated_resource_instances, callback = std::move(callback)] { - CreateRuntimeEnv(job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback); + CreateRuntimeEnv(job_id, + serialized_runtime_env, + serialized_allocated_resource_instances, + callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); return; @@ -197,17 +205,23 @@ void AgentManager::CreateRuntimeEnv( request.set_serialized_allocated_resource_instances( serialized_allocated_resource_instances); runtime_env_agent_client_->CreateRuntimeEnv( - request, [this, job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback = std::move(callback)]( - const Status &status, const rpc::CreateRuntimeEnvReply &reply) { + request, + [this, + job_id, + serialized_runtime_env, + serialized_allocated_resource_instances, + callback = std::move(callback)](const Status &status, + const rpc::CreateRuntimeEnvReply &reply) { if (status.ok()) { if (reply.status() == rpc::AGENT_RPC_STATUS_OK) { - callback(true, reply.serialized_runtime_env_context(), + callback(true, + reply.serialized_runtime_env_context(), /*setup_error_message*/ ""); } else { RAY_LOG(INFO) << "Failed to create runtime env: " << serialized_runtime_env << ", error message: " << reply.error_message(); - callback(false, reply.serialized_runtime_env_context(), + callback(false, + reply.serialized_runtime_env_context(), /*setup_error_message*/ reply.error_message()); } @@ -218,10 +232,15 @@ void AgentManager::CreateRuntimeEnv( << ", status = " << status << ", maybe there are some network problems, will retry it later."; delay_executor_( - [this, job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback = std::move(callback)] { - CreateRuntimeEnv(job_id, serialized_runtime_env, - serialized_allocated_resource_instances, callback); + [this, + job_id, + serialized_runtime_env, + serialized_allocated_resource_instances, + callback = std::move(callback)] { + CreateRuntimeEnv(job_id, + serialized_runtime_env, + serialized_allocated_resource_instances, + callback); }, RayConfig::instance().agent_manager_retry_interval_ms()); } @@ -241,27 +260,27 @@ void AgentManager::DeleteURIs(const std::vector &uris, for (const auto &uri : uris) { request.add_uris(uri); } - runtime_env_agent_client_->DeleteURIs(request, [this, uris, callback]( - Status status, - const rpc::DeleteURIsReply &reply) { - if (status.ok()) { - if (reply.status() == rpc::AGENT_RPC_STATUS_OK) { - callback(true); - } else { - // TODO(sang): Find a 
better way of delivering error messages in this case.
-        RAY_LOG(ERROR) << "Failed to delete URIs"
-                       << ", error message: " << reply.error_message();
-        callback(false);
-      }
+  runtime_env_agent_client_->DeleteURIs(
+      request, [this, uris, callback](Status status, const rpc::DeleteURIsReply &reply) {
+        if (status.ok()) {
+          if (reply.status() == rpc::AGENT_RPC_STATUS_OK) {
+            callback(true);
+          } else {
+            // TODO(sang): Find a better way of delivering error messages in this case.
+            RAY_LOG(ERROR) << "Failed to delete URIs"
+                           << ", error message: " << reply.error_message();
+            callback(false);
+          }
-    } else {
-      RAY_LOG(ERROR) << "Failed to delete URIs"
-                     << ", status = " << status
-                     << ", maybe there are some network problems, will retry it later.";
-      delay_executor_([this, uris, callback] { DeleteURIs(uris, callback); },
-                      RayConfig::instance().agent_manager_retry_interval_ms());
-    }
-  });
+        } else {
+          RAY_LOG(ERROR)
+              << "Failed to delete URIs"
+              << ", status = " << status
+              << ", maybe there are some network problems, will retry it later.";
+          delay_executor_([this, uris, callback] { DeleteURIs(uris, callback); },
+                          RayConfig::instance().agent_manager_retry_interval_ms());
+        }
+      });
 }

 }  // namespace raylet
diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h
index 1340818c6..51052f9e0 100644
--- a/src/ray/raylet/agent_manager.h
+++ b/src/ray/raylet/agent_manager.h
@@ -53,7 +53,8 @@ class AgentManager : public rpc::AgentManagerServiceHandler {
     std::vector<std::string> agent_commands;
   };

-  explicit AgentManager(Options options, DelayExecutorFn delay_executor,
+  explicit AgentManager(Options options,
+                        DelayExecutorFn delay_executor,
                         RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory,
                         bool start_agent = true /* for test */)
       : options_(std::move(options)),
@@ -71,7 +72,8 @@ class AgentManager : public rpc::AgentManagerServiceHandler {
   /// Request agent to create a runtime env.
   /// \param[in] runtime_env The runtime env.
   virtual void CreateRuntimeEnv(
-      const JobID &job_id, const std::string &serialized_runtime_env,
+      const JobID &job_id,
+      const std::string &serialized_runtime_env,
       const std::string &serialized_allocated_resource_instances,
       CreateRuntimeEnvCallback callback);

diff --git a/src/ray/raylet/local_object_manager.cc b/src/ray/raylet/local_object_manager.cc
index dfefda528..8663011bd 100644
--- a/src/ray/raylet/local_object_manager.cc
+++ b/src/ray/raylet/local_object_manager.cc
@@ -74,8 +74,8 @@ void LocalObjectManager::PinObjectsAndWaitForFree(
     const auto object_eviction_msg = msg.worker_object_eviction_message();
     const auto object_id = ObjectID::FromBinary(object_eviction_msg.object_id());
     ReleaseFreedObject(object_id);
-    core_worker_subscriber_->Unsubscribe(rpc::ChannelType::WORKER_OBJECT_EVICTION,
-                                         owner_address, object_id.Binary());
+    core_worker_subscriber_->Unsubscribe(
+        rpc::ChannelType::WORKER_OBJECT_EVICTION, owner_address, object_id.Binary());
   };

   // Callback that is invoked when the owner of the object id is dead.
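A side note on the AgentManager hunks above: CreateRuntimeEnv and DeleteURIs both retry by re-posting the whole call onto delay_executor_ when the RPC fails. A minimal, self-contained sketch of that pattern (DelayExecutor, the helper name, and the retry budget are invented for illustration; this is not Ray's API):

    #include <functional>
    #include <iostream>

    // Stand-in for delay_executor_: runs `work` after `delay_ms` milliseconds.
    using DelayExecutor = std::function<void(std::function<void()>, int delay_ms)>;

    void DeleteUrisWithRetry(DelayExecutor delay_executor,
                             std::function<bool()> rpc_call,
                             int attempts_left) {
      if (rpc_call() || attempts_left == 0) {
        return;  // Success, or retry budget exhausted.
      }
      // On failure, schedule the same operation again, mirroring
      // delay_executor_([this, uris, callback] { DeleteURIs(uris, callback); }, ...).
      delay_executor(
          [delay_executor, rpc_call, attempts_left] {
            DeleteUrisWithRetry(delay_executor, rpc_call, attempts_left - 1);
          },
          /*delay_ms=*/1000);
    }

    int main() {
      // Toy executor that ignores the delay and runs the work inline.
      DelayExecutor inline_executor = [](std::function<void()> work, int) { work(); };
      int calls = 0;
      DeleteUrisWithRetry(
          inline_executor, [&calls] { return ++calls == 3; }, /*attempts_left=*/5);
      std::cout << calls << std::endl;  // The "RPC" succeeds on try 3: prints 3.
      return 0;
    }
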
@@ -88,10 +88,13 @@ void LocalObjectManager::PinObjectsAndWaitForFree( auto sub_message = std::make_unique(); sub_message->mutable_worker_object_eviction_message()->Swap(wait_request.get()); - RAY_CHECK(core_worker_subscriber_->Subscribe( - std::move(sub_message), rpc::ChannelType::WORKER_OBJECT_EVICTION, owner_address, - object_id.Binary(), /*subscribe_done_callback=*/nullptr, subscription_callback, - owner_dead_callback)); + RAY_CHECK(core_worker_subscriber_->Subscribe(std::move(sub_message), + rpc::ChannelType::WORKER_OBJECT_EVICTION, + owner_address, + object_id.Binary(), + /*subscribe_done_callback=*/nullptr, + subscription_callback, + owner_dead_callback)); } } @@ -191,45 +194,48 @@ bool LocalObjectManager::SpillObjectsOfSize(int64_t num_bytes_to_spill) { RAY_LOG(DEBUG) << "Spilling objects of total size " << bytes_to_spill << " num objects " << objects_to_spill.size(); auto start_time = absl::GetCurrentTimeNanos(); - SpillObjectsInternal(objects_to_spill, [this, bytes_to_spill, objects_to_spill, - start_time](const Status &status) { - if (!status.ok()) { - RAY_LOG(DEBUG) << "Failed to spill objects: " << status.ToString(); - } else { - auto now = absl::GetCurrentTimeNanos(); - RAY_LOG(DEBUG) << "Spilled " << bytes_to_spill << " bytes in " - << (now - start_time) / 1e6 << "ms"; - spilled_bytes_total_ += bytes_to_spill; - spilled_objects_total_ += objects_to_spill.size(); - // Adjust throughput timing to account for concurrent spill operations. - spill_time_total_s_ += (now - std::max(start_time, last_spill_finish_ns_)) / 1e9; - if (now - last_spill_log_ns_ > 1e9) { - last_spill_log_ns_ = now; - std::stringstream msg; - // Keep :info_message: in sync with LOG_PREFIX_INFO_MESSAGE in ray_constants.py. - msg << ":info_message:Spilled " - << static_cast(spilled_bytes_total_ / (1024 * 1024)) << " MiB, " - << spilled_objects_total_ << " objects, write throughput " - << static_cast(spilled_bytes_total_ / (1024 * 1024) / - spill_time_total_s_) - << " MiB/s."; - if (next_spill_error_log_bytes_ > 0 && - spilled_bytes_total_ >= next_spill_error_log_bytes_) { - // Add an advisory the first time this is logged. - if (next_spill_error_log_bytes_ == - RayConfig::instance().verbose_spill_logs()) { - msg << " Set RAY_verbose_spill_logs=0 to disable this message."; - } - // Exponential backoff on the spill messages. - next_spill_error_log_bytes_ *= 2; - RAY_LOG(ERROR) << msg.str(); + SpillObjectsInternal( + objects_to_spill, + [this, bytes_to_spill, objects_to_spill, start_time](const Status &status) { + if (!status.ok()) { + RAY_LOG(DEBUG) << "Failed to spill objects: " << status.ToString(); } else { - RAY_LOG(INFO) << msg.str(); + auto now = absl::GetCurrentTimeNanos(); + RAY_LOG(DEBUG) << "Spilled " << bytes_to_spill << " bytes in " + << (now - start_time) / 1e6 << "ms"; + spilled_bytes_total_ += bytes_to_spill; + spilled_objects_total_ += objects_to_spill.size(); + // Adjust throughput timing to account for concurrent spill operations. + spill_time_total_s_ += + (now - std::max(start_time, last_spill_finish_ns_)) / 1e9; + if (now - last_spill_log_ns_ > 1e9) { + last_spill_log_ns_ = now; + std::stringstream msg; + // Keep :info_message: in sync with LOG_PREFIX_INFO_MESSAGE in + // ray_constants.py. 
+ msg << ":info_message:Spilled " + << static_cast(spilled_bytes_total_ / (1024 * 1024)) << " MiB, " + << spilled_objects_total_ << " objects, write throughput " + << static_cast(spilled_bytes_total_ / (1024 * 1024) / + spill_time_total_s_) + << " MiB/s."; + if (next_spill_error_log_bytes_ > 0 && + spilled_bytes_total_ >= next_spill_error_log_bytes_) { + // Add an advisory the first time this is logged. + if (next_spill_error_log_bytes_ == + RayConfig::instance().verbose_spill_logs()) { + msg << " Set RAY_verbose_spill_logs=0 to disable this message."; + } + // Exponential backoff on the spill messages. + next_spill_error_log_bytes_ *= 2; + RAY_LOG(ERROR) << msg.str(); + } else { + RAY_LOG(INFO) << msg.str(); + } + } + last_spill_finish_ns_ = now; } - } - last_spill_finish_ns_ = now; - } - }); + }); return true; } return false; @@ -299,8 +305,9 @@ void LocalObjectManager::SpillObjectsInternal( } } io_worker->rpc_client()->SpillObjects( - request, [this, requested_objects_to_spill, callback, io_worker]( - const ray::Status &status, const rpc::SpillObjectsReply &r) { + request, + [this, requested_objects_to_spill, callback, io_worker]( + const ray::Status &status, const rpc::SpillObjectsReply &r) { { absl::MutexLock lock(&mutex_); num_active_workers_ -= 1; @@ -421,7 +428,8 @@ std::string LocalObjectManager::GetLocalSpilledObjectURL(const ObjectID &object_ } void LocalObjectManager::AsyncRestoreSpilledObject( - const ObjectID &object_id, const std::string &object_url, + const ObjectID &object_id, + const std::string &object_url, std::function callback) { if (objects_pending_restore_.count(object_id) > 0) { // If the same object is restoring, we dedup here. @@ -536,8 +544,9 @@ void LocalObjectManager::DeleteSpilledObjects(std::vector &urls_to_ request.add_spilled_objects_url(std::move(url)); } io_worker->rpc_client()->DeleteSpilledObjects( - request, [this, io_worker](const ray::Status &status, - const rpc::DeleteSpilledObjectsReply &reply) { + request, + [this, io_worker](const ray::Status &status, + const rpc::DeleteSpilledObjectsReply &reply) { io_worker_pool_.PushDeleteWorker(io_worker); if (!status.ok()) { RAY_LOG(ERROR) << "Failed to send delete spilled object request: " diff --git a/src/ray/raylet/local_object_manager.h b/src/ray/raylet/local_object_manager.h index c266f202c..89c2cd263 100644 --- a/src/ray/raylet/local_object_manager.h +++ b/src/ray/raylet/local_object_manager.h @@ -37,10 +37,16 @@ namespace raylet { class LocalObjectManager { public: LocalObjectManager( - const NodeID &node_id, std::string self_node_address, int self_node_port, - size_t free_objects_batch_size, int64_t free_objects_period_ms, - IOWorkerPoolInterface &io_worker_pool, rpc::CoreWorkerClientPool &owner_client_pool, - int max_io_workers, int64_t min_spilling_size, bool is_external_storage_type_fs, + const NodeID &node_id, + std::string self_node_address, + int self_node_port, + size_t free_objects_batch_size, + int64_t free_objects_period_ms, + IOWorkerPoolInterface &io_worker_pool, + rpc::CoreWorkerClientPool &owner_client_pool, + int max_io_workers, + int64_t min_spilling_size, + bool is_external_storage_type_fs, int64_t max_fused_object_count, std::function &)> on_objects_freed, std::function is_plasma_object_spillable, @@ -99,7 +105,8 @@ class LocalObjectManager { /// \param object_url The URL where the object is spilled. /// \param callback A callback to call when the restoration is done. /// Status will contain the error during restoration, if any. 
- void AsyncRestoreSpilledObject(const ObjectID &object_id, const std::string &object_url, + void AsyncRestoreSpilledObject(const ObjectID &object_id, + const std::string &object_url, std::function callback); /// Clear any freed objects. This will trigger the callback for freed diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 5673cb2b3..2cb2ada95 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -35,13 +35,17 @@ DEFINE_string(node_ip_address, "", "The ip address of this node."); DEFINE_string(gcs_address, "", "The address of the GCS server, including IP and port."); DEFINE_string(redis_address, "", "The IP address of redis server."); DEFINE_int32(redis_port, -1, "The port of redis server."); -DEFINE_int32(min_worker_port, 0, +DEFINE_int32(min_worker_port, + 0, "The lowest port that workers' gRPC servers will bind on."); -DEFINE_int32(max_worker_port, 0, +DEFINE_int32(max_worker_port, + 0, "The highest port that workers' gRPC servers will bind on."); -DEFINE_string(worker_port_list, "", +DEFINE_string(worker_port_list, + "", "An explicit list of ports that workers' gRPC servers will bind on."); -DEFINE_int32(num_initial_python_workers_for_first_job, 0, +DEFINE_int32(num_initial_python_workers_for_first_job, + 0, "Number of initial Python workers for the first job."); DEFINE_int32(maximum_startup_concurrency, 1, "Maximum startup concurrency."); DEFINE_string(static_resource_list, "", "The static resource list of this node."); @@ -49,7 +53,8 @@ DEFINE_string(python_worker_command, "", "Python worker command."); DEFINE_string(java_worker_command, "", "Java worker command."); DEFINE_string(agent_command, "", "Dashboard agent command."); DEFINE_string(cpp_worker_command, "", "CPP worker command."); -DEFINE_string(native_library_path, "", +DEFINE_string(native_library_path, + "", "The native library path which includes the core libraries."); DEFINE_string(redis_password, "", "The password of redis."); DEFINE_string(temp_dir, "", "Temporary directory."); @@ -60,10 +65,12 @@ DEFINE_int32(ray_debugger_external, 0, "Make Ray debugger externally accessible. // store options DEFINE_int64(object_store_memory, -1, "The initial memory of the object store."); #ifdef __linux__ -DEFINE_string(plasma_directory, "/dev/shm", +DEFINE_string(plasma_directory, + "/dev/shm", "The shared memory directory of the object store."); #else -DEFINE_string(plasma_directory, "/tmp", +DEFINE_string(plasma_directory, + "/tmp", "The shared memory directory of the object store."); #endif DEFINE_bool(huge_pages, false, "Enable huge pages."); @@ -71,7 +78,8 @@ DEFINE_bool(huge_pages, false, "Enable huge pages."); int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLog::ShutDownRayLog, + argv[0], ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(argv[0]); @@ -129,9 +137,12 @@ int main(int argc, char *argv[]) { } else { // Async context is not used by `redis_client_` in `gcs_client`, so we set // `enable_async_conn` as false. 
- ray::gcs::GcsClientOptions client_options( - redis_address, redis_port, redis_password, /*enable_sync_conn=*/true, - /*enable_async_conn=*/false, /*enable_subscribe_conn=*/true); + ray::gcs::GcsClientOptions client_options(redis_address, + redis_port, + redis_password, + /*enable_sync_conn=*/true, + /*enable_async_conn=*/false, + /*enable_subscribe_conn=*/true); gcs_client = std::make_shared(client_options); } @@ -262,14 +273,19 @@ int main(int argc, char *argv[]) { ray::stats::Init(global_tags, metrics_agent_port); // Initialize the node manager. - raylet = std::make_unique( - main_service, raylet_socket_name, node_ip_address, node_manager_config, - object_manager_config, gcs_client, metrics_export_port); + raylet = std::make_unique(main_service, + raylet_socket_name, + node_ip_address, + node_manager_config, + object_manager_config, + gcs_client, + metrics_export_port); // Initialize event framework. if (RayConfig::instance().event_log_reporter_enabled() && !log_dir.empty()) { ray::RayEventInit(ray::rpc::Event_SourceType::Event_SourceType_RAYLET, - {{"node_id", raylet->GetNodeId().Hex()}}, log_dir, + {{"node_id", raylet->GetNodeId().Hex()}}, + log_dir, RayConfig::instance().event_level()); }; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e866e314d..75d4d1749 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -172,7 +172,8 @@ void HeartbeatSender::Heartbeat() { })); } -NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self_node_id, +NodeManager::NodeManager(instrumented_io_context &io_service, + const NodeID &self_node_id, const NodeManagerConfig &config, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) @@ -180,11 +181,18 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self io_service_(io_service), gcs_client_(gcs_client), worker_pool_( - io_service, self_node_id_, config.node_manager_address, - config.num_workers_soft_limit, config.num_initial_python_workers_for_first_job, - config.maximum_startup_concurrency, config.min_worker_port, - config.max_worker_port, config.worker_ports, gcs_client_, - config.worker_commands, config.native_library_path, + io_service, + self_node_id_, + config.node_manager_address, + config.num_workers_soft_limit, + config.num_initial_python_workers_for_first_job, + config.maximum_startup_concurrency, + config.min_worker_port, + config.max_worker_port, + config.worker_ports, + gcs_client_, + config.worker_commands, + config.native_library_path, /*starting_worker_timeout_callback=*/ [this] { cluster_task_manager_->ScheduleAndDispatchTasks(); }, config.ray_debugger_external, @@ -205,7 +213,9 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self }, &io_service_)), object_directory_(std::make_unique( - io_service_, gcs_client_, core_worker_subscriber_.get(), + io_service_, + gcs_client_, + core_worker_subscriber_.get(), /*owner_client_pool=*/&worker_rpc_pool_, /*max_object_report_batch_size=*/ RayConfig::instance().max_object_report_batch_size(), @@ -215,11 +225,15 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self MarkObjectsAsFailed(error_type, {ref}, JobID::Nil()); })), object_manager_( - io_service, self_node_id, object_manager_config, object_directory_.get(), - [this](const ObjectID &object_id, const std::string &object_url, + io_service, + self_node_id, + object_manager_config, + object_directory_.get(), + [this](const ObjectID &object_id, + 
const std::string &object_url, std::function callback) { - GetLocalObjectManager().AsyncRestoreSpilledObject(object_id, object_url, - callback); + GetLocalObjectManager().AsyncRestoreSpilledObject( + object_id, object_url, callback); }, /*get_spilled_object_url=*/ [this](const ObjectID &object_id) { @@ -259,8 +273,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self [this](const ObjectID &object_id) { rpc::ObjectReference ref; ref.set_object_id(object_id.Binary()); - MarkObjectsAsFailed(rpc::ErrorType::OBJECT_FETCH_TIMED_OUT, {ref}, - JobID::Nil()); + MarkObjectsAsFailed( + rpc::ErrorType::OBJECT_FETCH_TIMED_OUT, {ref}, JobID::Nil()); }), periodical_runner_(io_service), report_resources_period_ms_(config.report_resources_period_ms), @@ -275,16 +289,20 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self [this](std::function fn, int64_t delay_ms) { RAY_UNUSED(execute_after(io_service_, fn, delay_ms)); }), - node_manager_server_("NodeManager", config.node_manager_port, + node_manager_server_("NodeManager", + config.node_manager_port, config.node_manager_address == "127.0.0.1"), node_manager_service_(io_service, *this), agent_manager_service_handler_( new DefaultAgentManagerServiceHandler(agent_manager_)), agent_manager_service_(io_service, *agent_manager_service_handler_), local_object_manager_( - self_node_id_, config.node_manager_address, config.node_manager_port, + self_node_id_, + config.node_manager_address, + config.node_manager_port, RayConfig::instance().free_objects_batch_size(), - RayConfig::instance().free_objects_period_milliseconds(), worker_pool_, + RayConfig::instance().free_objects_period_milliseconds(), + worker_pool_, worker_rpc_pool_, /*max_io_workers*/ config.max_io_workers, /*min_spilling_size*/ config.min_spilling_size, @@ -314,7 +332,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self cluster_resource_scheduler_ = std::shared_ptr(new ClusterResourceScheduler( scheduling::NodeID(self_node_id_.Binary()), - local_resources.GetTotalResources().GetResourceMap(), *gcs_client_, + local_resources.GetTotalResources().GetResourceMap(), + *gcs_client_, [this]() { if (RayConfig::instance().scheduler_report_pinned_bytes_only()) { return local_object_manager_.GetPinnedBytes(); @@ -351,7 +370,10 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self local_task_manager_ = std::make_shared( self_node_id_, std::dynamic_pointer_cast(cluster_resource_scheduler_), - dependency_manager_, is_owner_alive, get_node_info_func, worker_pool_, + dependency_manager_, + is_owner_alive, + get_node_info_func, + worker_pool_, leased_workers_, [this](const std::vector &object_ids, std::vector> *results) { @@ -361,7 +383,9 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self cluster_task_manager_ = std::make_shared( self_node_id_, std::dynamic_pointer_cast(cluster_resource_scheduler_), - get_node_info_func, announce_infeasible_task, local_task_manager_); + get_node_info_func, + announce_infeasible_task, + local_task_manager_); placement_group_resource_manager_ = std::make_shared( std::dynamic_pointer_cast(cluster_resource_scheduler_)); @@ -381,7 +405,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self for (auto &arg : agent_command_line) { auto node_manager_port_position = arg.find(kNodeManagerPortPlaceholder); if (node_manager_port_position != std::string::npos) { - arg.replace(node_manager_port_position, 
strlen(kNodeManagerPortPlaceholder), + arg.replace(node_manager_port_position, + strlen(kNodeManagerPortPlaceholder), std::to_string(GetServerPort())); } } @@ -642,7 +667,8 @@ void NodeManager::DoLocalGC() { void NodeManager::HandleRequestObjectSpillage( const rpc::RequestObjectSpillageRequest &request, - rpc::RequestObjectSpillageReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::RequestObjectSpillageReply *reply, + rpc::SendReplyCallback send_reply_callback) { const auto &object_id = ObjectID::FromBinary(request.object_id()); RAY_LOG(DEBUG) << "Received RequestObjectSpillage for object " << object_id; local_object_manager_.SpillObjects( @@ -661,7 +687,8 @@ void NodeManager::HandleRequestObjectSpillage( void NodeManager::HandleReleaseUnusedBundles( const rpc::ReleaseUnusedBundlesRequest &request, - rpc::ReleaseUnusedBundlesReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::ReleaseUnusedBundlesReply *reply, + rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Releasing unused bundles."; std::unordered_set in_use_bundles; for (int index = 0; index < request.bundles_in_use_size(); ++index) { @@ -769,9 +796,11 @@ void NodeManager::WarnResourceDeadlock() { RAY_LOG(WARNING) << error_message_str; RAY_LOG_EVERY_MS(WARNING, 10 * 1000) << cluster_task_manager_->DebugStr(); if (RayConfig::instance().legacy_scheduler_warnings()) { - auto error_data_ptr = gcs::CreateErrorTableData( - "resource_deadlock", error_message_str, current_time_ms(), - exemplar.GetTaskSpecification().JobId()); + auto error_data_ptr = + gcs::CreateErrorTableData("resource_deadlock", + error_message_str, + current_time_ms(), + exemplar.GetTaskSpecification().JobId()); RAY_CHECK_OK(gcs_client_->Errors().AsyncReportJobError(error_data_ptr, nullptr)); } } @@ -923,7 +952,8 @@ void NodeManager::ResourceCreateUpdated(const NodeID &node_id, const std::string &resource_label = resource_pair.first; const double &new_resource_capacity = resource_pair.second; cluster_resource_scheduler_->GetClusterResourceManager().UpdateResourceCapacity( - scheduling::NodeID(node_id.Binary()), scheduling::ResourceID(resource_label), + scheduling::NodeID(node_id.Binary()), + scheduling::ResourceID(resource_label), new_resource_capacity); } RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map."; @@ -1047,7 +1077,8 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr & // TODO(ekl) this is still used from core worker even in direct call mode to // finish up get requests. 
     auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
-    AsyncResolveObjectsFinish(client, from_flatbuf<TaskID>(*message->task_id()),
+    AsyncResolveObjectsFinish(client,
+                              from_flatbuf<TaskID>(*message->task_id()),
                               /*was_blocked*/ true);
   } break;
   case protocol::MessageType::WaitRequest: {
@@ -1097,9 +1128,16 @@ void NodeManager::ProcessRegisterClientRequestMessage(
   } else {
     RAY_CHECK(job_id.IsNil());
   }
-  auto worker = std::dynamic_pointer_cast<WorkerInterface>(std::make_shared<Worker>(
-      job_id, runtime_env_hash, worker_id, language, worker_type, worker_ip_address,
-      client, client_call_manager_, worker_startup_token));
+  auto worker = std::dynamic_pointer_cast<WorkerInterface>(
+      std::make_shared<Worker>(job_id,
+                               runtime_env_hash,
+                               worker_id,
+                               language,
+                               worker_type,
+                               worker_ip_address,
+                               client,
+                               client_call_manager_,
+                               worker_startup_token));
 
   auto send_reply_callback = [this, client, job_id](Status status, int assigned_port) {
     flatbuffers::FlatBufferBuilder fbb;
@@ -1108,14 +1146,19 @@ void NodeManager::ProcessRegisterClientRequestMessage(
     if (job_config != boost::none) {
       serialized_job_config = (*job_config).SerializeAsString();
     }
-    auto reply = ray::protocol::CreateRegisterClientReply(
-        fbb, status.ok(), fbb.CreateString(status.ToString()),
-        to_flatbuf(fbb, self_node_id_), assigned_port,
-        fbb.CreateString(serialized_job_config));
+    auto reply =
+        ray::protocol::CreateRegisterClientReply(fbb,
+                                                 status.ok(),
+                                                 fbb.CreateString(status.ToString()),
+                                                 to_flatbuf(fbb, self_node_id_),
+                                                 assigned_port,
+                                                 fbb.CreateString(serialized_job_config));
     fbb.Finish(reply);
     client->WriteMessageAsync(
-        static_cast<int64_t>(protocol::MessageType::RegisterClientReply), fbb.GetSize(),
-        fbb.GetBufferPointer(), [this, client](const ray::Status &status) {
+        static_cast<int64_t>(protocol::MessageType::RegisterClientReply),
+        fbb.GetSize(),
+        fbb.GetBufferPointer(),
+        [this, client](const ray::Status &status) {
          if (!status.ok()) {
            DisconnectClient(client);
          }
@@ -1125,8 +1168,8 @@ void NodeManager::ProcessRegisterClientRequestMessage(
       worker_type == rpc::WorkerType::SPILL_WORKER ||
       worker_type == rpc::WorkerType::RESTORE_WORKER) {
     // Register the new worker.
-    auto status = worker_pool_.RegisterWorker(worker, pid, worker_startup_token,
-                                              send_reply_callback);
+    auto status = worker_pool_.RegisterWorker(
+        worker, pid, worker_startup_token, send_reply_callback);
     if (!status.ok()) {
       // If the worker failed to register to Raylet, trigger task dispatching here to
       // allow new worker processes to be started (if capped by
@@ -1144,12 +1187,16 @@ void NodeManager::ProcessRegisterClientRequestMessage(
     job_config.ParseFromString(message->serialized_job_config()->str());
 
     // Send the reply callback only after registration fully completes at the GCS.
-    auto cb = [this, worker_ip_address, pid, job_id, job_config,
+    auto cb = [this,
+               worker_ip_address,
+               pid,
+               job_id,
+               job_config,
                send_reply_callback = std::move(send_reply_callback)](const Status &status,
                                                                      int assigned_port) {
       if (status.ok()) {
-        auto job_data_ptr = gcs::CreateJobTableData(job_id, /*is_dead*/ false,
-                                                    worker_ip_address, pid, job_config);
+        auto job_data_ptr = gcs::CreateJobTableData(
+            job_id, /*is_dead*/ false, worker_ip_address, pid, job_config);
         RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(
             job_data_ptr,
             [send_reply_callback = std::move(send_reply_callback), assigned_port](
@@ -1251,9 +1298,13 @@ void NodeManager::DisconnectClient(const std::shared_ptr<ClientConnection> &clie
                   << ", worker_id: " << worker->WorkerId();
   }
   // Publish the worker failure.
- auto worker_failure_data_ptr = gcs::CreateWorkerFailureData( - self_node_id_, worker->WorkerId(), worker->IpAddress(), worker->Port(), - time(nullptr), disconnect_type, creation_task_exception); + auto worker_failure_data_ptr = gcs::CreateWorkerFailureData(self_node_id_, + worker->WorkerId(), + worker->IpAddress(), + worker->Port(), + time(nullptr), + disconnect_type, + creation_task_exception); RAY_CHECK_OK( gcs_client_->Workers().AsyncReportWorkerFailure(worker_failure_data_ptr, nullptr)); @@ -1375,7 +1426,10 @@ void NodeManager::ProcessFetchOrReconstructMessage( // pulled from remote node managers. If an object's owner dies, an error // will be stored as the object's value. const TaskID task_id = from_flatbuf(*message->task_id()); - AsyncResolveObjects(client, refs, task_id, /*ray_get=*/true, + AsyncResolveObjects(client, + refs, + task_id, + /*ray_get=*/true, /*mark_worker_blocked*/ message->mark_worker_blocked()); } } @@ -1412,12 +1466,17 @@ void NodeManager::ProcessWaitRequestMessage( // already local. Missing objects will be pulled from remote node managers. // If an object's owner dies, an error will be stored as the object's // value. - AsyncResolveObjects(client, refs, current_task_id, /*ray_get=*/false, + AsyncResolveObjects(client, + refs, + current_task_id, + /*ray_get=*/false, /*mark_worker_blocked*/ was_blocked); } uint64_t num_required_objects = static_cast(message->num_ready_objects()); wait_manager_.Wait( - object_ids, message->timeout(), num_required_objects, + object_ids, + message->timeout(), + num_required_objects, [this, resolve_objects, was_blocked, client, current_task_id]( std::vector ready, std::vector remaining) { // Write the data. @@ -1428,7 +1487,8 @@ void NodeManager::ProcessWaitRequestMessage( auto status = client->WriteMessage(static_cast(protocol::MessageType::WaitReply), - fbb.GetSize(), fbb.GetBufferPointer()); + fbb.GetSize(), + fbb.GetBufferPointer()); if (status.ok()) { // The client is unblocked now because the wait call has returned. if (resolve_objects) { @@ -1452,10 +1512,15 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage( // managers or store an error if the objects have failed. 
const auto refs = FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); - AsyncResolveObjects(client, refs, TaskID::Nil(), /*ray_get=*/false, + AsyncResolveObjects(client, + refs, + TaskID::Nil(), + /*ray_get=*/false, /*mark_worker_blocked*/ false); wait_manager_.Wait( - object_ids, -1, object_ids.size(), + object_ids, + -1, + object_ids.size(), [this, client, tag](std::vector ready, std::vector remaining) { RAY_CHECK(remaining.empty()); std::shared_ptr worker = @@ -1480,7 +1545,8 @@ void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { } void NodeManager::HandleUpdateResourceUsage( - const rpc::UpdateResourceUsageRequest &request, rpc::UpdateResourceUsageReply *reply, + const rpc::UpdateResourceUsageRequest &request, + rpc::UpdateResourceUsageReply *reply, rpc::SendReplyCallback send_reply_callback) { rpc::ResourceUsageBroadcastData resource_usage_batch; resource_usage_batch.ParseFromString(request.serialized_resource_usage_batch()); @@ -1531,7 +1597,8 @@ void NodeManager::HandleUpdateResourceUsage( void NodeManager::HandleRequestResourceReport( const rpc::RequestResourceReportRequest &request, - rpc::RequestResourceReportReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::RequestResourceReportReply *reply, + rpc::SendReplyCallback send_reply_callback) { auto resources_data = reply->mutable_resources(); FillResourceReport(*resources_data); resources_data->set_cluster_full_of_actors_detected(resource_deadlock_warned_ >= 1); @@ -1540,7 +1607,8 @@ void NodeManager::HandleRequestResourceReport( } void NodeManager::HandleReportWorkerBacklog( - const rpc::ReportWorkerBacklogRequest &request, rpc::ReportWorkerBacklogReply *reply, + const rpc::ReportWorkerBacklogRequest &request, + rpc::ReportWorkerBacklogReply *reply, rpc::SendReplyCallback send_reply_callback) { const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); local_task_manager_->ClearWorkerBacklog(worker_id); @@ -1549,8 +1617,8 @@ void NodeManager::HandleReportWorkerBacklog( const TaskSpecification resource_spec(backlog_report.resource_spec()); const SchedulingClass scheduling_class = resource_spec.GetSchedulingClass(); RAY_CHECK(seen.find(scheduling_class) == seen.end()); - local_task_manager_->SetWorkerBacklog(scheduling_class, worker_id, - backlog_report.backlog_size()); + local_task_manager_->SetWorkerBacklog( + scheduling_class, worker_id, backlog_report.backlog_size()); } send_reply_callback(Status::OK(), nullptr, nullptr); } @@ -1581,43 +1649,46 @@ void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest worker_pool_.PrestartWorkers(task_spec, request.backlog_size(), available_cpus); } - auto send_reply_callback_wrapper = [this, is_actor_creation_task, actor_id, reply, - send_reply_callback]( - Status status, std::function success, - std::function failure) { - // If resources are not enough due to normal tasks' preemption - // for GCS based actor scheduling, return a rejection - // with normal task resource usages so GCS can update - // its resource view of this raylet. - if (reply->rejected() && is_actor_creation_task) { - ResourceSet normal_task_resources = local_task_manager_->CalcNormalTaskResources(); - RAY_LOG(DEBUG) << "Reject leasing as the raylet has no enough resources." 
- << " actor_id = " << actor_id - << ", normal_task_resources = " << normal_task_resources.ToString() - << ", local_resoruce_view = " - << cluster_resource_scheduler_->GetClusterResourceManager() - .GetNodeResourceViewString( - scheduling::NodeID(self_node_id_.Binary())); - auto resources_data = reply->mutable_resources_data(); - resources_data->set_node_id(self_node_id_.Binary()); - resources_data->set_resources_normal_task_changed(true); - auto &normal_task_map = *(resources_data->mutable_resources_normal_task()); - normal_task_map = {normal_task_resources.GetResourceMap().begin(), - normal_task_resources.GetResourceMap().end()}; - resources_data->set_resources_normal_task_timestamp(absl::GetCurrentTimeNanos()); - } + auto send_reply_callback_wrapper = + [this, is_actor_creation_task, actor_id, reply, send_reply_callback]( + Status status, std::function success, std::function failure) { + // If resources are not enough due to normal tasks' preemption + // for GCS based actor scheduling, return a rejection + // with normal task resource usages so GCS can update + // its resource view of this raylet. + if (reply->rejected() && is_actor_creation_task) { + ResourceSet normal_task_resources = + local_task_manager_->CalcNormalTaskResources(); + RAY_LOG(DEBUG) << "Reject leasing as the raylet has no enough resources." + << " actor_id = " << actor_id << ", normal_task_resources = " + << normal_task_resources.ToString() << ", local_resoruce_view = " + << cluster_resource_scheduler_->GetClusterResourceManager() + .GetNodeResourceViewString( + scheduling::NodeID(self_node_id_.Binary())); + auto resources_data = reply->mutable_resources_data(); + resources_data->set_node_id(self_node_id_.Binary()); + resources_data->set_resources_normal_task_changed(true); + auto &normal_task_map = *(resources_data->mutable_resources_normal_task()); + normal_task_map = {normal_task_resources.GetResourceMap().begin(), + normal_task_resources.GetResourceMap().end()}; + resources_data->set_resources_normal_task_timestamp( + absl::GetCurrentTimeNanos()); + } - send_reply_callback(status, success, failure); - }; + send_reply_callback(status, success, failure); + }; - cluster_task_manager_->QueueAndScheduleTask(task, request.grant_or_reject(), + cluster_task_manager_->QueueAndScheduleTask(task, + request.grant_or_reject(), request.is_selected_based_on_locality(), - reply, send_reply_callback_wrapper); + reply, + send_reply_callback_wrapper); } void NodeManager::HandlePrepareBundleResources( const rpc::PrepareBundleResourcesRequest &request, - rpc::PrepareBundleResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::PrepareBundleResourcesReply *reply, + rpc::SendReplyCallback send_reply_callback) { std::vector> bundle_specs; for (int index = 0; index < request.bundle_specs_size(); index++) { bundle_specs.emplace_back( @@ -1632,7 +1703,8 @@ void NodeManager::HandlePrepareBundleResources( void NodeManager::HandleCommitBundleResources( const rpc::CommitBundleResourcesRequest &request, - rpc::CommitBundleResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::CommitBundleResourcesReply *reply, + rpc::SendReplyCallback send_reply_callback) { std::vector> bundle_specs; for (int index = 0; index < request.bundle_specs_size(); index++) { bundle_specs.emplace_back( @@ -1648,7 +1720,8 @@ void NodeManager::HandleCommitBundleResources( void NodeManager::HandleCancelResourceReserve( const rpc::CancelResourceReserveRequest &request, - rpc::CancelResourceReserveReply *reply, rpc::SendReplyCallback 
send_reply_callback) { + rpc::CancelResourceReserveReply *reply, + rpc::SendReplyCallback send_reply_callback) { auto bundle_spec = BundleSpecification(request.bundle_spec()); RAY_LOG(DEBUG) << "Request to cancel reserved resource is received, " << bundle_spec.DebugString(); @@ -1750,7 +1823,8 @@ void NodeManager::HandleShutdownRaylet(const rpc::ShutdownRayletRequest &request void NodeManager::HandleReleaseUnusedWorkers( const rpc::ReleaseUnusedWorkersRequest &request, - rpc::ReleaseUnusedWorkersReply *reply, rpc::SendReplyCallback send_reply_callback) { + rpc::ReleaseUnusedWorkersReply *reply, + rpc::SendReplyCallback send_reply_callback) { std::unordered_set in_use_worker_ids; for (int index = 0; index < request.worker_ids_in_use_size(); ++index) { auto worker_id = WorkerID::FromBinary(request.worker_ids_in_use(index)); @@ -1787,7 +1861,8 @@ void NodeManager::HandleCancelWorkerLease(const rpc::CancelWorkerLeaseRequest &r } void NodeManager::MarkObjectsAsFailed( - const ErrorType &error_type, const std::vector objects_to_fail, + const ErrorType &error_type, + const std::vector objects_to_fail, const JobID &job_id) { // TODO(swang): Ideally we should return the error directly to the client // that needs this object instead of storing the object in plasma, which is @@ -1801,8 +1876,12 @@ void NodeManager::MarkObjectsAsFailed( std::shared_ptr data; Status status; status = store_client_.TryCreateImmediately( - object_id, ref.owner_address(), 0, - reinterpret_cast(meta.c_str()), meta.length(), &data, + object_id, + ref.owner_address(), + 0, + reinterpret_cast(meta.c_str()), + meta.length(), + &data, plasma::flatbuf::ObjectSource::ErrorStoredByRaylet); if (status.ok()) { status = store_client_.Seal(object_id); @@ -1853,7 +1932,9 @@ void NodeManager::HandleDirectCallTaskUnblocked( void NodeManager::AsyncResolveObjects( const std::shared_ptr &client, const std::vector &required_object_refs, - const TaskID ¤t_task_id, bool ray_get, bool mark_worker_blocked) { + const TaskID ¤t_task_id, + bool ray_get, + bool mark_worker_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (!worker) { // The client is a driver. Drivers do not hold resources, so we simply mark @@ -1874,7 +1955,8 @@ void NodeManager::AsyncResolveObjects( } void NodeManager::AsyncResolveObjectsFinish( - const std::shared_ptr &client, const TaskID ¤t_task_id, + const std::shared_ptr &client, + const TaskID ¤t_task_id, bool was_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (!worker) { @@ -2166,8 +2248,8 @@ void NodeManager::HandlePinObjectIDs(const rpc::PinObjectIDsRequest &request, return; } // Wait for the object to be freed by the owner, which keeps the ref count. 
-  local_object_manager_.PinObjectsAndWaitForFree(object_ids, std::move(results),
-                                                 owner_address);
+  local_object_manager_.PinObjectsAndWaitForFree(
+      object_ids, std::move(results), owner_address);
   send_reply_callback(Status::OK(), nullptr, nullptr);
 }
 
@@ -2179,7 +2261,8 @@ void NodeManager::HandleGetSystemConfig(const rpc::GetSystemConfigRequest &reque
 }
 
 void NodeManager::HandleGetGcsServerAddress(
-    const rpc::GetGcsServerAddressRequest &request, rpc::GetGcsServerAddressReply *reply,
+    const rpc::GetGcsServerAddressRequest &request,
+    rpc::GetGcsServerAddressReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   auto address = gcs_client_->GetGcsServerAddress();
   reply->set_ip(address.first);
@@ -2257,8 +2340,9 @@ void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_
     request.set_intended_worker_id(worker->WorkerId().Binary());
     request.set_include_memory_info(node_stats_request.include_memory_info());
     worker->rpc_client()->GetCoreWorkerStats(
-        request, [reply, worker, all_workers, driver_ids, send_reply_callback](
-                     const ray::Status &status, const rpc::GetCoreWorkerStatsReply &r) {
+        request,
+        [reply, worker, all_workers, driver_ids, send_reply_callback](
+            const ray::Status &status, const rpc::GetCoreWorkerStatsReply &r) {
          reply->add_core_workers_stats()->MergeFrom(r.core_worker_stats());
          reply->set_num_workers(reply->num_workers() + 1);
          if (reply->num_workers() == all_workers.size()) {
@@ -2389,7 +2473,8 @@ std::string FormatMemoryInfo(std::vector<rpc::GetNodeStatsReply> node_stats)
 
 void NodeManager::HandleFormatGlobalMemoryInfo(
     const rpc::FormatGlobalMemoryInfoRequest &request,
-    rpc::FormatGlobalMemoryInfoReply *reply, rpc::SendReplyCallback send_reply_callback) {
+    rpc::FormatGlobalMemoryInfoReply *reply,
+    rpc::SendReplyCallback send_reply_callback) {
   auto replies = std::make_shared<std::vector<rpc::GetNodeStatsReply>>();
   auto local_request = std::make_shared<rpc::GetNodeStatsRequest>();
   auto local_reply = std::make_shared<rpc::GetNodeStatsReply>();
@@ -2400,39 +2485,42 @@ void NodeManager::HandleFormatGlobalMemoryInfo(
   rpc::GetNodeStatsRequest stats_req;
   stats_req.set_include_memory_info(include_memory_info);
 
-  auto store_reply = [replies, reply, num_nodes, send_reply_callback,
-                      include_memory_info](const rpc::GetNodeStatsReply &local_reply) {
-    replies->push_back(local_reply);
-    if (replies->size() >= num_nodes) {
-      if (include_memory_info) {
-        reply->set_memory_summary(FormatMemoryInfo(*replies));
-      }
-      reply->mutable_store_stats()->CopyFrom(AccumulateStoreStats(*replies));
-      send_reply_callback(Status::OK(), nullptr, nullptr);
-    }
-  };
+  auto store_reply =
+      [replies, reply, num_nodes, send_reply_callback, include_memory_info](
+          const rpc::GetNodeStatsReply &local_reply) {
+        replies->push_back(local_reply);
+        if (replies->size() >= num_nodes) {
+          if (include_memory_info) {
+            reply->set_memory_summary(FormatMemoryInfo(*replies));
+          }
+          reply->mutable_store_stats()->CopyFrom(AccumulateStoreStats(*replies));
+          send_reply_callback(Status::OK(), nullptr, nullptr);
+        }
+      };
 
   // Fetch from remote nodes.
   for (const auto &entry : remote_node_manager_addresses_) {
     auto client = std::make_unique<rpc::NodeManagerClient>(
         entry.second.first, entry.second.second, client_call_manager_);
-    client->GetNodeStats(
-        stats_req, [replies, store_reply](const ray::Status &status,
-                                          const rpc::GetNodeStatsReply &r) {
-          if (!status.ok()) {
-            RAY_LOG(ERROR) << "Failed to get remote node stats: " << status.ToString();
-          }
-          store_reply(r);
-        });
+    client->GetNodeStats(stats_req,
+                         [replies, store_reply](const ray::Status &status,
+                                                const rpc::GetNodeStatsReply &r) {
+                           if (!status.ok()) {
+                             RAY_LOG(ERROR) << "Failed to get remote node stats: "
+                                            << status.ToString();
+                           }
+                           store_reply(r);
+                         });
   }
 
   // Fetch from the local node.
-  HandleGetNodeStats(
-      stats_req, local_reply.get(),
-      [local_reply, store_reply](Status status, std::function<void()> success,
-                                 std::function<void()> failure) mutable {
-        store_reply(*local_reply);
-      });
+  HandleGetNodeStats(stats_req,
+                     local_reply.get(),
+                     [local_reply, store_reply](Status status,
+                                                std::function<void()> success,
+                                                std::function<void()> failure) mutable {
+                       store_reply(*local_reply);
+                     });
 }
 
 void NodeManager::HandleGlobalGC(const rpc::GlobalGCRequest &request,
@@ -2496,7 +2584,9 @@ void NodeManager::PublishInfeasibleTaskError(const RayTask &task) const {
   RAY_LOG(WARNING) << error_message_str;
   if (RayConfig::instance().legacy_scheduler_warnings()) {
     auto error_data_ptr =
-        gcs::CreateErrorTableData(type, error_message_str, current_time_ms(),
+        gcs::CreateErrorTableData(type,
+                                  error_message_str,
+                                  current_time_ms(),
                                   task.GetTaskSpecification().JobId());
     RAY_CHECK_OK(gcs_client_->Errors().AsyncReportJobError(error_data_ptr, nullptr));
   }
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index e37911809..5f7a603a5 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -143,7 +143,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   ///
   /// \param resource_config The initial set of node resources.
   /// \param object_manager A reference to the local object manager.
-  NodeManager(instrumented_io_context &io_service, const NodeID &self_node_id,
+  NodeManager(instrumented_io_context &io_service,
+              const NodeID &self_node_id,
               const NodeManagerConfig &config,
               const ObjectManagerConfig &object_manager_config,
               std::shared_ptr<gcs::GcsClient> gcs_client);
@@ -163,7 +164,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// \param message_data A pointer to the message data.
   /// \return Void.
   void ProcessClientMessage(const std::shared_ptr<ClientConnection> &client,
-                            int64_t message_type, const uint8_t *message_data);
+                            int64_t message_type,
+                            const uint8_t *message_data);
 
   /// Subscribe to the relevant GCS tables and set up handlers.
   ///
@@ -296,7 +298,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// \return Void.
   void AsyncResolveObjects(const std::shared_ptr<ClientConnection> &client,
                            const std::vector<rpc::ObjectReference> &required_object_refs,
-                           const TaskID &current_task_id, bool ray_get,
+                           const TaskID &current_task_id,
+                           bool ray_get,
                            bool mark_worker_blocked);
 
   /// Handle end of a blocking object get. This could be a task assigned to a
@@ -311,7 +314,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// blocked in AsyncResolveObjects().
   /// \return Void.
   void AsyncResolveObjectsFinish(const std::shared_ptr<ClientConnection> &client,
-                                 const TaskID &current_task_id, bool was_blocked);
+                                 const TaskID &current_task_id,
+                                 bool was_blocked);
 
   /// Handle a direct call task that is blocked. Note that this callback may
   /// arrive after the worker lease has been returned to the node manager.
@@ -527,7 +531,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::SendReplyCallback send_reply_callback) override; /// Handle a `GlobalGC` request. - void HandleGlobalGC(const rpc::GlobalGCRequest &request, rpc::GlobalGCReply *reply, + void HandleGlobalGC(const rpc::GlobalGCRequest &request, + rpc::GlobalGCReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Handle a `FormatGlobalMemoryInfo`` request. diff --git a/src/ray/raylet/placement_group_resource_manager.cc b/src/ray/raylet/placement_group_resource_manager.cc index b6bbc3cad..b3e0c7d0a 100644 --- a/src/ray/raylet/placement_group_resource_manager.cc +++ b/src/ray/raylet/placement_group_resource_manager.cc @@ -67,8 +67,9 @@ bool NewPlacementGroupResourceManager::PrepareBundle( auto bundle_state = std::make_shared(CommitState::PREPARED, resource_instances); pg_bundles_[bundle_spec.BundleId()] = bundle_state; - bundle_spec_map_.emplace(bundle_spec.BundleId(), std::make_shared( - bundle_spec.GetMessage())); + bundle_spec_map_.emplace( + bundle_spec.BundleId(), + std::make_shared(bundle_spec.GetMessage())); return true; } diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index e5ba06ec1..8293e6e5b 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -25,7 +25,8 @@ namespace { const std::vector GenerateEnumNames(const char *const *enum_names_ptr, - int start_index, int end_index) { + int start_index, + int end_index) { std::vector enum_names; for (int i = 0; i < start_index; ++i) { enum_names.push_back("EmptyMessageType"); @@ -54,19 +55,24 @@ namespace ray { namespace raylet { -Raylet::Raylet(instrumented_io_context &main_service, const std::string &socket_name, +Raylet::Raylet(instrumented_io_context &main_service, + const std::string &socket_name, const std::string &node_ip_address, const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client, int metrics_export_port) + std::shared_ptr gcs_client, + int metrics_export_port) : main_service_(main_service), self_node_id_( !RayConfig::instance().OVERRIDE_NODE_ID_FOR_TESTING().empty() ? NodeID::FromHex(RayConfig::instance().OVERRIDE_NODE_ID_FOR_TESTING()) : NodeID::FromRandom()), gcs_client_(gcs_client), - node_manager_(main_service, self_node_id_, node_manager_config, - object_manager_config, gcs_client_), + node_manager_(main_service, + self_node_id_, + node_manager_config, + object_manager_config, + gcs_client_), socket_name_(socket_name), acceptor_(main_service, ParseUrlEndpoint(socket_name)), socket_(main_service) { @@ -118,8 +124,9 @@ ray::Status Raylet::RegisterGcs() { } void Raylet::DoAccept() { - acceptor_.async_accept(socket_, boost::bind(&Raylet::HandleAccept, this, - boost::asio::placeholders::error)); + acceptor_.async_accept( + socket_, + boost::bind(&Raylet::HandleAccept, this, boost::asio::placeholders::error)); } void Raylet::HandleAccept(const boost::system::error_code &error) { @@ -141,9 +148,13 @@ void Raylet::HandleAccept(const boost::system::error_code &error) { fbb.GetBufferPointer() + fbb.GetSize()); // Accept a new local client and dispatch it to the node manager. 
auto new_connection = ClientConnection::Create( - client_handler, message_handler, std::move(socket_), "worker", + client_handler, + message_handler, + std::move(socket_), + "worker", node_manager_message_enum, - static_cast(protocol::MessageType::DisconnectClient), message_data); + static_cast(protocol::MessageType::DisconnectClient), + message_data); } // We're ready to accept another client. DoAccept(); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 1f770d0eb..747a9578c 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -47,10 +47,13 @@ class Raylet { /// manager. /// \param gcs_client A client connection to the GCS. /// \param metrics_export_port A port at which metrics are exposed to. - Raylet(instrumented_io_context &main_service, const std::string &socket_name, - const std::string &node_ip_address, const NodeManagerConfig &node_manager_config, + Raylet(instrumented_io_context &main_service, + const std::string &socket_name, + const std::string &node_ip_address, + const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client, int metrics_export_port); + std::shared_ptr gcs_client, + int metrics_export_port); /// Start this raylet. void Start(); diff --git a/src/ray/raylet/scheduling/cluster_resource_data.cc b/src/ray/raylet/scheduling/cluster_resource_data.cc index 525c4c7a1..dc4d6fb79 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.cc +++ b/src/ray/raylet/scheduling/cluster_resource_data.cc @@ -20,9 +20,10 @@ namespace ray { using namespace ::ray::scheduling; -const std::string resource_labels[] = { - ray::kCPU_ResourceLabel, ray::kMemory_ResourceLabel, ray::kGPU_ResourceLabel, - ray::kObjectStoreMemory_ResourceLabel}; +const std::string resource_labels[] = {ray::kCPU_ResourceLabel, + ray::kMemory_ResourceLabel, + ray::kGPU_ResourceLabel, + ray::kObjectStoreMemory_ResourceLabel}; const std::string ResourceEnumToString(PredefinedResources resource) { // TODO (Alex): We should replace this with a protobuf enum. 
diff --git a/src/ray/raylet/scheduling/cluster_resource_manager.cc b/src/ray/raylet/scheduling/cluster_resource_manager.cc index 8475de504..7167ba7b0 100644 --- a/src/ray/raylet/scheduling/cluster_resource_manager.cc +++ b/src/ray/raylet/scheduling/cluster_resource_manager.cc @@ -209,8 +209,9 @@ bool ClusterResourceManager::SubtractNodeAvailableResources( for (size_t i = 0; i < PredefinedResources_MAX; i++) { resources->predefined_resources[i].available = - std::max(FixedPoint(0), resources->predefined_resources[i].available - - resource_request.predefined_resources[i]); + std::max(FixedPoint(0), + resources->predefined_resources[i].available - + resource_request.predefined_resources[i]); } for (const auto &task_req_custom_resource : resource_request.custom_resources) { @@ -229,7 +230,8 @@ bool ClusterResourceManager::SubtractNodeAvailableResources( } bool ClusterResourceManager::HasSufficientResource( - scheduling::NodeID node_id, const ResourceRequest &resource_request, + scheduling::NodeID node_id, + const ResourceRequest &resource_request, bool ignore_object_store_memory_requirement) const { auto it = nodes_.find(node_id); if (it == nodes_.end()) { diff --git a/src/ray/raylet/scheduling/cluster_resource_manager.h b/src/ray/raylet/scheduling/cluster_resource_manager.h index c95f5f9ec..7f7fcbd53 100644 --- a/src/ray/raylet/scheduling/cluster_resource_manager.h +++ b/src/ray/raylet/scheduling/cluster_resource_manager.h @@ -66,7 +66,8 @@ class ClusterResourceManager { /// \param resource_id: Resource which we want to update. /// \param resource_total: New capacity of the resource. void UpdateResourceCapacity(scheduling::NodeID node_id, - scheduling::ResourceID resource_id, double resource_total); + scheduling::ResourceID resource_id, + double resource_total); /// Delete a given resource from a given node. 
/// diff --git a/src/ray/raylet/scheduling/cluster_resource_manager_test.cc b/src/ray/raylet/scheduling/cluster_resource_manager_test.cc index 96f67e9e2..fb6b29b7e 100644 --- a/src/ray/raylet/scheduling/cluster_resource_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_resource_manager_test.cc @@ -18,7 +18,8 @@ namespace ray { -NodeResources CreateNodeResources(double available_cpu, double total_cpu, +NodeResources CreateNodeResources(double available_cpu, + double total_cpu, double available_custom_resource = 0, double total_custom_resource = 0, bool object_pulls_queued = false) { @@ -36,13 +37,17 @@ struct ClusterResourceManagerTest : public ::testing::Test { manager = std::make_unique(); manager->AddOrUpdateNode(node0, CreateNodeResources(/*available_cpu*/ 1, /*total_cpu*/ 1)); - manager->AddOrUpdateNode( - node1, CreateNodeResources(/*available_cpu*/ 0, /*total_cpu*/ 0, - /*available_custom*/ 1, /*total_custom*/ 1)); - manager->AddOrUpdateNode( - node2, CreateNodeResources(/*available_cpu*/ 1, /*total_cpu*/ 1, - /*available_custom*/ 1, /*total_custom*/ 1, - /*object_pulls_queued*/ true)); + manager->AddOrUpdateNode(node1, + CreateNodeResources(/*available_cpu*/ 0, + /*total_cpu*/ 0, + /*available_custom*/ 1, + /*total_custom*/ 1)); + manager->AddOrUpdateNode(node2, + CreateNodeResources(/*available_cpu*/ 1, + /*total_cpu*/ 1, + /*available_custom*/ 1, + /*total_custom*/ 1, + /*object_pulls_queued*/ true)); } scheduling::NodeID node0 = scheduling::NodeID(0); scheduling::NodeID node1 = scheduling::NodeID(1); diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 09524d24a..4772ad449 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -29,53 +29,65 @@ ClusterResourceScheduler::ClusterResourceScheduler() NodeResources node_resources; node_resources.predefined_resources.resize(PredefinedResources_MAX); local_resource_manager_ = std::make_unique( - local_node_id_, node_resources, - /*get_used_object_store_memory*/ nullptr, /*get_pull_manager_at_capacity*/ nullptr, + local_node_id_, + node_resources, + /*get_used_object_store_memory*/ nullptr, + /*get_pull_manager_at_capacity*/ nullptr, [&](const NodeResources &local_resource_update) { cluster_resource_manager_->AddOrUpdateNode(local_node_id_, local_resource_update); }); scheduling_policy_ = std::make_unique( - local_node_id_, cluster_resource_manager_->GetResourceView(), + local_node_id_, + cluster_resource_manager_->GetResourceView(), [this](auto node_id) { return this->NodeAlive(node_id); }); } ClusterResourceScheduler::ClusterResourceScheduler( - scheduling::NodeID local_node_id, const NodeResources &local_node_resources, + scheduling::NodeID local_node_id, + const NodeResources &local_node_resources, gcs::GcsClient &gcs_client) : local_node_id_(local_node_id), gcs_client_(&gcs_client) { cluster_resource_manager_ = std::make_unique(); local_resource_manager_ = std::make_unique( - local_node_id, local_node_resources, - /*get_used_object_store_memory*/ nullptr, /*get_pull_manager_at_capacity*/ nullptr, + local_node_id, + local_node_resources, + /*get_used_object_store_memory*/ nullptr, + /*get_pull_manager_at_capacity*/ nullptr, [&](const NodeResources &local_resource_update) { cluster_resource_manager_->AddOrUpdateNode(local_node_id_, local_resource_update); }); cluster_resource_manager_->AddOrUpdateNode(local_node_id_, local_node_resources); scheduling_policy_ = 
std::make_unique( - local_node_id_, cluster_resource_manager_->GetResourceView(), + local_node_id_, + cluster_resource_manager_->GetResourceView(), [this](auto node_id) { return this->NodeAlive(node_id); }); } ClusterResourceScheduler::ClusterResourceScheduler( scheduling::NodeID local_node_id, const absl::flat_hash_map &local_node_resources, - gcs::GcsClient &gcs_client, std::function get_used_object_store_memory, + gcs::GcsClient &gcs_client, + std::function get_used_object_store_memory, std::function get_pull_manager_at_capacity) : local_node_id_(local_node_id), gcs_client_(&gcs_client) { NodeResources node_resources = ResourceMapToNodeResources(local_node_resources, local_node_resources); cluster_resource_manager_ = std::make_unique(); local_resource_manager_ = std::make_unique( - local_node_id_, node_resources, get_used_object_store_memory, - get_pull_manager_at_capacity, [&](const NodeResources &local_resource_update) { + local_node_id_, + node_resources, + get_used_object_store_memory, + get_pull_manager_at_capacity, + [&](const NodeResources &local_resource_update) { cluster_resource_manager_->AddOrUpdateNode(local_node_id_, local_resource_update); }); cluster_resource_manager_->AddOrUpdateNode(local_node_id_, node_resources); scheduling_policy_ = std::make_unique( - local_node_id_, cluster_resource_manager_->GetResourceView(), + local_node_id_, + cluster_resource_manager_->GetResourceView(), [this](auto node_id) { return this->NodeAlive(node_id); }); } @@ -95,14 +107,18 @@ bool ClusterResourceScheduler::IsSchedulable(const ResourceRequest &resource_req // will eventually spill the task back from the waiting queue if its args // cannot be pulled. return cluster_resource_manager_->HasSufficientResource( - node_id, resource_request, + node_id, + resource_request, /*ignore_object_store_memory_requirement*/ node_id == local_node_id_); } scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode( const ResourceRequest &resource_request, - const rpc::SchedulingStrategy &scheduling_strategy, bool actor_creation, - bool force_spillback, int64_t *total_violations, bool *is_infeasible) { + const rpc::SchedulingStrategy &scheduling_strategy, + bool actor_creation, + bool force_spillback, + int64_t *total_violations, + bool *is_infeasible) { // The zero cpu actor is a special case that must be handled the same way by all // scheduling policies. if (actor_creation && resource_request.IsEmpty()) { @@ -112,17 +128,19 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode( auto best_node_id = scheduling::NodeID::Nil(); if (scheduling_strategy.scheduling_strategy_case() == rpc::SchedulingStrategy::SchedulingStrategyCase::kSpreadSchedulingStrategy) { - best_node_id = scheduling_policy_->Schedule( - resource_request, SchedulingOptions::Spread( - /*avoid_local_node*/ force_spillback, - /*require_node_available*/ force_spillback)); + best_node_id = + scheduling_policy_->Schedule(resource_request, + SchedulingOptions::Spread( + /*avoid_local_node*/ force_spillback, + /*require_node_available*/ force_spillback)); } else { // TODO (Alex): Setting require_available == force_spillback is a hack in order to // remain bug compatible with the legacy scheduling algorithms. 
- best_node_id = scheduling_policy_->Schedule( - resource_request, SchedulingOptions::Hybrid( - /*avoid_local_node*/ force_spillback, - /*require_node_available*/ force_spillback)); + best_node_id = + scheduling_policy_->Schedule(resource_request, + SchedulingOptions::Hybrid( + /*avoid_local_node*/ force_spillback, + /*require_node_available*/ force_spillback)); } *is_infeasible = best_node_id.IsNil(); @@ -142,13 +160,20 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode( scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode( const absl::flat_hash_map &task_resources, - const rpc::SchedulingStrategy &scheduling_strategy, bool requires_object_store_memory, - bool actor_creation, bool force_spillback, int64_t *total_violations, + const rpc::SchedulingStrategy &scheduling_strategy, + bool requires_object_store_memory, + bool actor_creation, + bool force_spillback, + int64_t *total_violations, bool *is_infeasible) { ResourceRequest resource_request = ResourceMapToResourceRequest(task_resources, requires_object_store_memory); - return GetBestSchedulableNode(resource_request, scheduling_strategy, actor_creation, - force_spillback, total_violations, is_infeasible); + return GetBestSchedulableNode(resource_request, + scheduling_strategy, + actor_creation, + force_spillback, + total_violations, + is_infeasible); } bool ClusterResourceScheduler::SubtractRemoteNodeAvailableResources( @@ -188,8 +213,11 @@ bool ClusterResourceScheduler::IsSchedulableOnNode( } scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode( - const TaskSpecification &task_spec, bool prioritize_local_node, - bool exclude_local_node, bool requires_object_store_memory, bool *is_infeasible) { + const TaskSpecification &task_spec, + bool prioritize_local_node, + bool exclude_local_node, + bool requires_object_store_memory, + bool *is_infeasible) { // If the local node is available, we should directly return it instead of // going through the full hybrid policy since we don't want spillback. if (prioritize_local_node && !exclude_local_node && @@ -203,8 +231,12 @@ scheduling::NodeID ClusterResourceScheduler::GetBestSchedulableNode( int64_t _unused; return GetBestSchedulableNode( task_spec.GetRequiredPlacementResources().GetResourceMap(), - task_spec.GetMessage().scheduling_strategy(), requires_object_store_memory, - task_spec.IsActorCreationTask(), exclude_local_node, &_unused, is_infeasible); + task_spec.GetMessage().scheduling_strategy(), + requires_object_store_memory, + task_spec.IsActorCreationTask(), + exclude_local_node, + &_unused, + is_infeasible); } } // namespace ray diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.h b/src/ray/raylet/scheduling/cluster_resource_scheduler.h index d8d716aa3..7bdca5db3 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.h +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.h @@ -148,8 +148,11 @@ class ClusterResourceScheduler { /// return the ID of a node that can schedule the resource request. scheduling::NodeID GetBestSchedulableNode( const ResourceRequest &resource_request, - const rpc::SchedulingStrategy &scheduling_strategy, bool actor_creation, - bool force_spillback, int64_t *violations, bool *is_infeasible); + const rpc::SchedulingStrategy &scheduling_strategy, + bool actor_creation, + bool force_spillback, + int64_t *violations, + bool *is_infeasible); /// Similar to /// int64_t GetBestSchedulableNode(...) 
@@ -160,8 +163,11 @@ class ClusterResourceScheduler { scheduling::NodeID GetBestSchedulableNode( const absl::flat_hash_map &resource_request, const rpc::SchedulingStrategy &scheduling_strategy, - bool requires_object_store_memory, bool actor_creation, bool force_spillback, - int64_t *violations, bool *is_infeasible); + bool requires_object_store_memory, + bool actor_creation, + bool force_spillback, + int64_t *violations, + bool *is_infeasible); /// Identifier of local node. scheduling::NodeID local_node_id_; diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc index b4135fc3e..1e4a7b790 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc @@ -52,8 +52,10 @@ vector EmptyIntVector; vector EmptyBoolVector; vector EmptyFixedPointVector; -void initResourceRequest(ResourceRequest &res_request, vector pred_demands, - vector cust_ids, vector cust_demands) { +void initResourceRequest(ResourceRequest &res_request, + vector pred_demands, + vector cust_ids, + vector cust_demands) { res_request.predefined_resources.resize(PredefinedResources_MAX + pred_demands.size()); for (size_t i = 0; i < pred_demands.size(); i++) { res_request.predefined_resources[i] = pred_demands[i]; @@ -68,7 +70,9 @@ void initResourceRequest(ResourceRequest &res_request, vector pred_d } }; -void addTaskResourceInstances(bool predefined, vector allocation, uint64_t idx, +void addTaskResourceInstances(bool predefined, + vector allocation, + uint64_t idx, TaskResourceInstances *task_allocation) { std::vector allocation_fp = VectorDoubleToVectorFixedPoint(allocation); @@ -83,8 +87,10 @@ void addTaskResourceInstances(bool predefined, vector allocation, uint64 } }; -void initNodeResources(NodeResources &node, vector &pred_capacities, - vector &cust_ids, vector &cust_capacities) { +void initNodeResources(NodeResources &node, + vector &pred_capacities, + vector &cust_ids, + vector &cust_capacities) { for (size_t i = 0; i < pred_capacities.size(); i++) { ResourceCapacity rc; rc.total = rc.available = pred_capacities[i]; @@ -371,8 +377,8 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingModifyClusterNodeTest) { TEST_F(ClusterResourceSchedulerTest, SpreadSchedulingStrategyTest) { absl::flat_hash_map resource_total({{"CPU", 10}}); auto local_node_id = scheduling::NodeID(NodeID::FromRandom().Binary()); - ClusterResourceScheduler resource_scheduler(local_node_id, resource_total, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + local_node_id, resource_total, *gcs_client_); AssertPredefinedNodeResources(); auto remote_node_id = scheduling::NodeID(NodeID::FromRandom().Binary()); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( @@ -383,15 +389,23 @@ TEST_F(ClusterResourceSchedulerTest, SpreadSchedulingStrategyTest) { bool is_infeasible; rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_spread_scheduling_strategy(); - auto node_id_1 = resource_scheduler.GetBestSchedulableNode( - resource_request, scheduling_strategy, false, false, false, &violations, - &is_infeasible); + auto node_id_1 = resource_scheduler.GetBestSchedulableNode(resource_request, + scheduling_strategy, + false, + false, + false, + &violations, + &is_infeasible); absl::flat_hash_map resource_available({{"CPU", 9}}); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( node_id_1, resource_total, resource_available); - auto node_id_2 = 
resource_scheduler.GetBestSchedulableNode( - resource_request, scheduling_strategy, false, false, false, &violations, - &is_infeasible); + auto node_id_2 = resource_scheduler.GetBestSchedulableNode(resource_request, + scheduling_strategy, + false, + false, + false, + &violations, + &is_infeasible); ASSERT_EQ((std::set{node_id_1, node_id_2}), (std::set{local_node_id, remote_node_id})); } @@ -403,8 +417,8 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingUpdateAvailableResourcesTest) { vector cust_ids{1, 2}; vector cust_capacities{5, 5}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(1), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(1), node_resources, *gcs_client_); AssertPredefinedNodeResources(); { @@ -457,8 +471,8 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingUpdateTotalResourcesTest) { absl::flat_hash_map initial_resources = { {ray::kCPU_ResourceLabel, 1}, {"custom", 1}}; std::string name = NodeID::FromRandom().Binary(); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(name), initial_resources, - *gcs_client_, nullptr, nullptr); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(name), initial_resources, *gcs_client_, nullptr, nullptr); resource_scheduler.GetLocalResourceManager().AddLocalResourceInstances( scheduling::ResourceID(ray::kCPU_ResourceLabel), {0, 1, 1}); @@ -528,8 +542,8 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) { vector cust_ids{1}; vector cust_capacities{10}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); auto node_id = NodeID::FromRandom(); rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_default_scheduling_strategy(); @@ -546,8 +560,8 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) { { ResourceRequest resource_request; vector pred_demands = {11}; - initResourceRequest(resource_request, pred_demands, EmptyIntVector, - EmptyFixedPointVector); + initResourceRequest( + resource_request, pred_demands, EmptyIntVector, EmptyFixedPointVector); int64_t violations; bool is_infeasible; auto node_id = resource_scheduler.GetBestSchedulableNode( @@ -559,8 +573,8 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingResourceRequestTest) { { ResourceRequest resource_request; vector pred_demands = {5}; - initResourceRequest(resource_request, pred_demands, EmptyIntVector, - EmptyFixedPointVector); + initResourceRequest( + resource_request, pred_demands, EmptyIntVector, EmptyFixedPointVector); int64_t violations; bool is_infeasible; auto node_id = resource_scheduler.GetBestSchedulableNode( @@ -637,8 +651,8 @@ TEST_F(ClusterResourceSchedulerTest, GetLocalAvailableResourcesWithCpuUnitTest) vector cust_ids{1}; vector cust_capacities{8}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); TaskResourceInstances available_cluster_resources = resource_scheduler.GetLocalResourceManager() @@ -670,8 +684,8 @@ TEST_F(ClusterResourceSchedulerTest, GetLocalAvailableResourcesTest) { 
vector cust_ids{1}; vector cust_capacities{8}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); TaskResourceInstances available_cluster_resources = resource_scheduler.GetLocalResourceManager() @@ -705,8 +719,8 @@ TEST_F(ClusterResourceSchedulerTest, GetCPUInstancesDoubleTest) { TEST_F(ClusterResourceSchedulerTest, AvailableResourceInstancesOpsTest) { NodeResources node_resources; vector pred_capacities{3 /* CPU */}; - initNodeResources(node_resources, pred_capacities, EmptyIntVector, - EmptyFixedPointVector); + initNodeResources( + node_resources, pred_capacities, EmptyIntVector, EmptyFixedPointVector); ClusterResourceScheduler cluster(scheduling::NodeID(0), node_resources, *gcs_client_); ResourceInstanceCapacities instances; @@ -738,15 +752,15 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstancesTest) { { NodeResources node_resources; vector pred_capacities{3. /* CPU */, 4. /* MEM */, 5. /* GPU */}; - initNodeResources(node_resources, pred_capacities, EmptyIntVector, - EmptyFixedPointVector); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + initNodeResources( + node_resources, pred_capacities, EmptyIntVector, EmptyFixedPointVector); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {3. /* CPU */, 2. /* MEM */, 1.5 /* GPU */}; - initResourceRequest(resource_request, pred_demands, EmptyIntVector, - EmptyFixedPointVector); + initResourceRequest( + resource_request, pred_demands, EmptyIntVector, EmptyFixedPointVector); NodeResourceInstances old_local_resources = resource_scheduler.GetLocalResourceManager().GetLocalResources(); @@ -771,15 +785,15 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstancesTest) { { NodeResources node_resources; vector pred_capacities{3 /* CPU */, 4 /* MEM */, 5 /* GPU */}; - initNodeResources(node_resources, pred_capacities, EmptyIntVector, - EmptyFixedPointVector); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + initNodeResources( + node_resources, pred_capacities, EmptyIntVector, EmptyFixedPointVector); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {4. /* CPU */, 2. /* MEM */, 1.5 /* GPU */}; - initResourceRequest(resource_request, pred_demands, EmptyIntVector, - EmptyFixedPointVector); + initResourceRequest( + resource_request, pred_demands, EmptyIntVector, EmptyFixedPointVector); NodeResourceInstances old_local_resources = resource_scheduler.GetLocalResourceManager().GetLocalResources(); @@ -802,8 +816,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstancesTest) { vector cust_ids{1, 2}; vector cust_capacities{4, 4}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {3. /* CPU */, 2. 
/* MEM */, 1.5 /* GPU */}; @@ -835,8 +849,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstancesTest) { vector cust_ids{1, 2}; vector cust_capacities{4, 4}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {3. /* CPU */, 2. /* MEM */, 1.5 /* GPU */}; @@ -865,8 +879,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstancesAllocationFailureTest) vector cust_ids{1, 2, 3}; vector cust_capacities{4, 4, 4}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {0. /* CPU */, 0. /* MEM */, 0. /* GPU */}; @@ -896,8 +910,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstancesTest2) { vector cust_ids{1, 2}; vector cust_capacities{4., 4.}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {2. /* CPU */, 2. /* MEM */, 1.5 /* GPU */}; @@ -939,15 +953,24 @@ TEST_F(ClusterResourceSchedulerTest, DeadNodeTest) { rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_default_scheduling_strategy(); ASSERT_EQ(scheduling::NodeID(node_id.Binary()), - resource_scheduler.GetBestSchedulableNode(resource, scheduling_strategy, - false, false, false, &violations, + resource_scheduler.GetBestSchedulableNode(resource, + scheduling_strategy, + false, + false, + false, + &violations, &is_infeasible)); EXPECT_CALL(*gcs_client_->mock_node_accessor, Get(node_id, ::testing::_)) .WillOnce(::testing::Return(nullptr)) .WillOnce(::testing::Return(nullptr)); ASSERT_TRUE(resource_scheduler - .GetBestSchedulableNode(resource, scheduling_strategy, false, false, - false, &violations, &is_infeasible) + .GetBestSchedulableNode(resource, + scheduling_strategy, + false, + false, + false, + &violations, + &is_infeasible) .IsNil()); } @@ -958,8 +981,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskGPUResourceInstancesTest) { vector cust_ids{1}; vector cust_capacities{8}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); std::vector allocate_gpu_instances{0.5, 0.5, 0.5, 0.5}; resource_scheduler.GetLocalResourceManager().SubtractResourceInstances( @@ -970,7 +993,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskGPUResourceInstancesTest) { .GetAvailableResourceInstances() .GetGPUInstancesDouble(); std::vector expected_available_gpu_instances{0.5, 0.5, 0.5, 0.5}; - ASSERT_TRUE(std::equal(available_gpu_instances.begin(), available_gpu_instances.end(), + ASSERT_TRUE(std::equal(available_gpu_instances.begin(), + available_gpu_instances.end(), expected_available_gpu_instances.begin())); 
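  // The rule these hunks apply is mechanical: a parameter or argument list
  // that fits within the column limit stays on one line, and one that does
  // not is broken with exactly one parameter per line instead of being
  // bin-packed two or three to a line as before. A minimal sketch of the two
  // cases (placeholder declarations, not code from this patch):
  #include <cstdint>
  #include <vector>

  struct NodeResources;  // stands in for the real Ray type

  // Fits within the column limit, so the new style leaves it on one line.
  void ResetNode(NodeResources *node);

  // Too long for one line: each parameter now gets its own line, mirroring
  // the initNodeResources helper reformatted above.
  void InitNodeResources(NodeResources &node,
                         const std::vector<double> &predefined_capacities,
                         const std::vector<int64_t> &custom_ids,
                         const std::vector<double> &custom_capacities);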
resource_scheduler.GetLocalResourceManager().AddResourceInstances( @@ -980,7 +1004,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskGPUResourceInstancesTest) { .GetAvailableResourceInstances() .GetGPUInstancesDouble(); expected_available_gpu_instances = {1., 1., 1., 1.}; - ASSERT_TRUE(std::equal(available_gpu_instances.begin(), available_gpu_instances.end(), + ASSERT_TRUE(std::equal(available_gpu_instances.begin(), + available_gpu_instances.end(), expected_available_gpu_instances.begin())); allocate_gpu_instances = {1.5, 1.5, .5, 1.5}; @@ -995,7 +1020,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskGPUResourceInstancesTest) { .GetAvailableResourceInstances() .GetGPUInstancesDouble(); expected_available_gpu_instances = {0., 0., 0.5, 0.}; - ASSERT_TRUE(std::equal(available_gpu_instances.begin(), available_gpu_instances.end(), + ASSERT_TRUE(std::equal(available_gpu_instances.begin(), + available_gpu_instances.end(), expected_available_gpu_instances.begin())); allocate_gpu_instances = {1.0, .5, 1., .5}; @@ -1009,7 +1035,8 @@ TEST_F(ClusterResourceSchedulerTest, TaskGPUResourceInstancesTest) { .GetAvailableResourceInstances() .GetGPUInstancesDouble(); expected_available_gpu_instances = {1., .5, 1., .5}; - ASSERT_TRUE(std::equal(available_gpu_instances.begin(), available_gpu_instances.end(), + ASSERT_TRUE(std::equal(available_gpu_instances.begin(), + available_gpu_instances.end(), expected_available_gpu_instances.begin())); } } @@ -1022,8 +1049,8 @@ TEST_F(ClusterResourceSchedulerTest, vector cust_ids{1}; vector cust_capacities{8}; initNodeResources(node_resources, pred_capacities, cust_ids, cust_capacities); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); { std::vector allocate_gpu_instances{0.5, 0.5, 2, 0.5}; @@ -1074,15 +1101,15 @@ TEST_F(ClusterResourceSchedulerTest, TEST_F(ClusterResourceSchedulerTest, TaskResourceInstanceWithHardRequestTest) { NodeResources node_resources; vector pred_capacities{4. /* CPU */, 2. /* MEM */, 4. /* GPU */}; - initNodeResources(node_resources, pred_capacities, EmptyIntVector, - EmptyFixedPointVector); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + initNodeResources( + node_resources, pred_capacities, EmptyIntVector, EmptyFixedPointVector); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {2. /* CPU */, 2. /* MEM */, 1.5 /* GPU */}; - initResourceRequest(resource_request, pred_demands, EmptyIntVector, - EmptyFixedPointVector); + initResourceRequest( + resource_request, pred_demands, EmptyIntVector, EmptyFixedPointVector); std::shared_ptr task_allocation = std::make_shared(); @@ -1101,15 +1128,15 @@ TEST_F(ClusterResourceSchedulerTest, TaskResourceInstanceWithHardRequestTest) { TEST_F(ClusterResourceSchedulerTest, TaskResourceInstanceWithoutCpuUnitTest) { NodeResources node_resources; vector pred_capacities{4. /* CPU */, 2. /* MEM */, 4. 
/* GPU */}; - initNodeResources(node_resources, pred_capacities, EmptyIntVector, - EmptyFixedPointVector); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID(0), node_resources, - *gcs_client_); + initNodeResources( + node_resources, pred_capacities, EmptyIntVector, EmptyFixedPointVector); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID(0), node_resources, *gcs_client_); ResourceRequest resource_request; vector pred_demands = {2. /* CPU */, 2. /* MEM */, 1.5 /* GPU */}; - initResourceRequest(resource_request, pred_demands, EmptyIntVector, - EmptyFixedPointVector); + initResourceRequest( + resource_request, pred_demands, EmptyIntVector, EmptyFixedPointVector); std::shared_ptr task_allocation = std::make_shared(); @@ -1141,8 +1168,13 @@ TEST_F(ClusterResourceSchedulerTest, TestAlwaysSpillInfeasibleTask) { rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_default_scheduling_strategy(); ASSERT_TRUE(resource_scheduler - .GetBestSchedulableNode(resource_spec, scheduling_strategy, false, - false, false, &total_violations, &is_infeasible) + .GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + false, + &total_violations, + &is_infeasible) .IsNil()); // Feasible remote node, but doesn't currently have resources available. We @@ -1150,18 +1182,28 @@ TEST_F(ClusterResourceSchedulerTest, TestAlwaysSpillInfeasibleTask) { auto remote_feasible = scheduling::NodeID(NodeID::FromRandom().Binary()); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( remote_feasible, resource_spec, {{"CPU", 0.}}); - ASSERT_EQ(remote_feasible, resource_scheduler.GetBestSchedulableNode( - resource_spec, scheduling_strategy, false, false, false, - &total_violations, &is_infeasible)); + ASSERT_EQ(remote_feasible, + resource_scheduler.GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + false, + &total_violations, + &is_infeasible)); // Feasible remote node, and it currently has resources available. We should // prefer to spill there. auto remote_available = scheduling::NodeID(NodeID::FromRandom().Binary()); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( remote_available, resource_spec, resource_spec); - ASSERT_EQ(remote_available, resource_scheduler.GetBestSchedulableNode( - resource_spec, scheduling_strategy, false, false, false, - &total_violations, &is_infeasible)); + ASSERT_EQ(remote_available, + resource_scheduler.GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + false, + &total_violations, + &is_infeasible)); } TEST_F(ClusterResourceSchedulerTest, ResourceUsageReportTest) { @@ -1171,13 +1213,13 @@ TEST_F(ClusterResourceSchedulerTest, ResourceUsageReportTest) { absl::flat_hash_map initial_resources( {{"CPU", 1}, {"GPU", 2}, {"memory", 3}, {"1", 1}, {"2", 2}, {"3", 3}}); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("0"), initial_resources, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("0"), initial_resources, *gcs_client_); NodeResources other_node_resources; vector other_pred_capacities{1. /* CPU */, 1. /* MEM */, 1. 
/* GPU */}; vector other_cust_capacities{5., 4., 3., 2., 1.}; - initNodeResources(other_node_resources, other_pred_capacities, cust_ids, - other_cust_capacities); + initNodeResources( + other_node_resources, other_pred_capacities, cust_ids, other_cust_capacities); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( scheduling::NodeID(12345), other_node_resources); @@ -1255,13 +1297,13 @@ TEST_F(ClusterResourceSchedulerTest, ObjectStoreMemoryUsageTest) { {"object_store_memory", 1000 * 1024 * 1024}}); int64_t used_object_store_memory = 250 * 1024 * 1024; int64_t *ptr = &used_object_store_memory; - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("0"), initial_resources, - *gcs_client_, [&] { return *ptr; }); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("0"), initial_resources, *gcs_client_, [&] { return *ptr; }); NodeResources other_node_resources; vector other_pred_capacities{1. /* CPU */, 1. /* MEM */, 1. /* GPU */}; vector other_cust_capacities{10.}; - initNodeResources(other_node_resources, other_pred_capacities, cust_ids, - other_cust_capacities); + initNodeResources( + other_node_resources, other_pred_capacities, cust_ids, other_cust_capacities); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( scheduling::NodeID(12345), other_node_resources); @@ -1352,11 +1394,11 @@ TEST_F(ClusterResourceSchedulerTest, ObjectStoreMemoryUsageTest) { TEST_F(ClusterResourceSchedulerTest, DirtyLocalViewTest) { absl::flat_hash_map initial_resources({{"CPU", 1}}); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("local"), - initial_resources, *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("local"), initial_resources, *gcs_client_); auto remote = scheduling::NodeID(NodeID::FromRandom().Binary()); - resource_scheduler.GetClusterResourceManager().AddOrUpdateNode(remote, {{"CPU", 2.}}, - {{"CPU", 2.}}); + resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( + remote, {{"CPU", 2.}}, {{"CPU", 2.}}); const absl::flat_hash_map task_spec = {{"CPU", 1.}}; // Allocate local resources to force tasks onto the remote node when @@ -1385,17 +1427,18 @@ TEST_F(ClusterResourceSchedulerTest, DirtyLocalViewTest) { resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( remote, {{"CPU", 2.}}, {{"CPU", num_slots_available}}); for (int j = 0; j < num_slots_available; j++) { - ASSERT_EQ(remote, resource_scheduler.GetBestSchedulableNode( - task_spec, scheduling_strategy, false, false, true, &t, - &is_infeasible)); + ASSERT_EQ( + remote, + resource_scheduler.GetBestSchedulableNode( + task_spec, scheduling_strategy, false, false, true, &t, &is_infeasible)); // Allocate remote resources. ASSERT_TRUE(resource_scheduler.AllocateRemoteTaskResources(remote, task_spec)); } // Our local view says there are not enough resources on the remote node to // schedule another task. 
ASSERT_EQ( - resource_scheduler.GetBestSchedulableNode(task_spec, scheduling_strategy, false, - false, true, &t, &is_infeasible), + resource_scheduler.GetBestSchedulableNode( + task_spec, scheduling_strategy, false, false, true, &t, &is_infeasible), scheduling::NodeID::Nil()); ASSERT_FALSE( resource_scheduler.GetLocalResourceManager().AllocateLocalTaskResources( @@ -1406,8 +1449,8 @@ TEST_F(ClusterResourceSchedulerTest, DirtyLocalViewTest) { } TEST_F(ClusterResourceSchedulerTest, DynamicResourceTest) { - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("local"), {{"CPU", 2}}, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("local"), {{"CPU", 2}}, *gcs_client_); absl::flat_hash_map resource_request = {{"CPU", 1}, {"custom123", 2}}; @@ -1446,8 +1489,8 @@ TEST_F(ClusterResourceSchedulerTest, DynamicResourceTest) { } TEST_F(ClusterResourceSchedulerTest, AvailableResourceEmptyTest) { - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("local"), - {{"custom123", 5}}, *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("local"), {{"custom123", 5}}, *gcs_client_); std::shared_ptr resource_instances = std::make_shared(); absl::flat_hash_map resource_request = {{"custom123", 5}}; @@ -1461,13 +1504,13 @@ TEST_F(ClusterResourceSchedulerTest, AvailableResourceEmptyTest) { TEST_F(ClusterResourceSchedulerTest, TestForceSpillback) { absl::flat_hash_map resource_spec({{"CPU", 1}}); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("local"), resource_spec, - *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("local"), resource_spec, *gcs_client_); std::vector node_ids; for (int i = 0; i < 100; i++) { node_ids.emplace_back(NodeID::FromRandom().Binary()); - resource_scheduler.GetClusterResourceManager().AddOrUpdateNode(node_ids.back(), {}, - {}); + resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( + node_ids.back(), {}, {}); } // No feasible nodes. @@ -1476,28 +1519,44 @@ TEST_F(ClusterResourceSchedulerTest, TestForceSpillback) { rpc::SchedulingStrategy scheduling_strategy; scheduling_strategy.mutable_default_scheduling_strategy(); // Normally we prefer local. - ASSERT_EQ(resource_scheduler.GetBestSchedulableNode( - resource_spec, scheduling_strategy, false, false, - /*force_spillback=*/false, &total_violations, &is_infeasible), + ASSERT_EQ(resource_scheduler.GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + /*force_spillback=*/false, + &total_violations, + &is_infeasible), scheduling::NodeID("local")); // If spillback is forced, we try to spill to remote, but only if there is a // schedulable node. - ASSERT_EQ(resource_scheduler.GetBestSchedulableNode( - resource_spec, scheduling_strategy, false, false, - /*force_spillback=*/true, &total_violations, &is_infeasible), + ASSERT_EQ(resource_scheduler.GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + /*force_spillback=*/true, + &total_violations, + &is_infeasible), scheduling::NodeID::Nil()); // Choose a remote node that has the resources available. 
- resource_scheduler.GetClusterResourceManager().AddOrUpdateNode(node_ids[50], - resource_spec, {}); - ASSERT_EQ(resource_scheduler.GetBestSchedulableNode( - resource_spec, scheduling_strategy, false, false, - /*force_spillback=*/true, &total_violations, &is_infeasible), + resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( + node_ids[50], resource_spec, {}); + ASSERT_EQ(resource_scheduler.GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + /*force_spillback=*/true, + &total_violations, + &is_infeasible), scheduling::NodeID::Nil()); resource_scheduler.GetClusterResourceManager().AddOrUpdateNode( node_ids[51], resource_spec, resource_spec); - ASSERT_EQ(resource_scheduler.GetBestSchedulableNode( - resource_spec, scheduling_strategy, false, false, - /*force_spillback=*/true, &total_violations, &is_infeasible), + ASSERT_EQ(resource_scheduler.GetBestSchedulableNode(resource_spec, + scheduling_strategy, + false, + false, + /*force_spillback=*/true, + &total_violations, + &is_infeasible), node_ids[51]); } @@ -1508,8 +1567,8 @@ TEST_F(ClusterResourceSchedulerTest, CustomResourceInstanceTest) { "custom_unit_instance_resources": "FPGA" } )"); - ClusterResourceScheduler resource_scheduler(scheduling::NodeID("local"), - {{"CPU", 4}, {"FPGA", 2}}, *gcs_client_); + ClusterResourceScheduler resource_scheduler( + scheduling::NodeID("local"), {{"CPU", 4}, {"FPGA", 2}}, *gcs_client_); StringIdMap mock_string_to_int_map; int64_t fpga_resource_id = mock_string_to_int_map.Insert("FPGA"); diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index c06eb2aa3..f10ab2c27 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -36,19 +36,23 @@ ClusterTaskManager::ClusterTaskManager( get_node_info_(get_node_info), announce_infeasible_task_(announce_infeasible_task), local_task_manager_(std::move(local_task_manager)), - scheduler_resource_reporter_(tasks_to_schedule_, infeasible_tasks_, - *local_task_manager_), + scheduler_resource_reporter_( + tasks_to_schedule_, infeasible_tasks_, *local_task_manager_), internal_stats_(*this, *local_task_manager_), get_time_ms_(get_time_ms) {} void ClusterTaskManager::QueueAndScheduleTask( - const RayTask &task, bool grant_or_reject, bool is_selected_based_on_locality, - rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) { + const RayTask &task, + bool grant_or_reject, + bool is_selected_based_on_locality, + rpc::RequestWorkerLeaseReply *reply, + rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Queuing and scheduling task " << task.GetTaskSpecification().TaskId(); auto work = std::make_shared( - task, grant_or_reject, is_selected_based_on_locality, reply, - [send_reply_callback] { send_reply_callback(Status::OK(), nullptr, nullptr); }); + task, grant_or_reject, is_selected_based_on_locality, reply, [send_reply_callback] { + send_reply_callback(Status::OK(), nullptr, nullptr); + }); const auto &scheduling_class = task.GetTaskSpecification().GetSchedulingClass(); // If the scheduling class is infeasible, just add the work to the infeasible queue // directly. 
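One nuance is visible in the QueueAndScheduleTask hunk just above: when the final argument is a lambda, clang-format keeps the preceding arguments together and breaks only the lambda body, rather than forcing every argument onto its own line. This behavior typically comes from disabling bin-packing (for example, BinPackParameters: false and BinPackArguments: false; the two lines actually added to .clang-format are not shown in this section). A sketch with placeholder names, not this patch's API:

  #include <functional>

  void Submit(int task_id, bool urgent, std::function<void()> on_done) {
    (void)task_id;
    (void)urgent;
    on_done();
  }

  void Example() {
    // The leading arguments fit before the lambda introducer, so only the
    // lambda body wraps; compare the std::make_shared call above.
    Submit(/*task_id=*/42, /*urgent=*/true, [] {
      // completion work runs here
    });
  }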
@@ -78,9 +82,11 @@ void ClusterTaskManager::ScheduleAndDispatchTasks() { RAY_LOG(DEBUG) << "Scheduling pending task " << task.GetTaskSpecification().TaskId(); auto scheduling_node_id = cluster_resource_scheduler_->GetBestSchedulableNode( - task.GetTaskSpecification(), work->PrioritizeLocalNode(), + task.GetTaskSpecification(), + work->PrioritizeLocalNode(), /*exclude_local_node*/ false, - /*requires_object_store_memory*/ false, &is_infeasible); + /*requires_object_store_memory*/ false, + &is_infeasible); // There is no node that has available resources to run the request. // Move on to the next shape. @@ -130,9 +136,11 @@ void ClusterTaskManager::TryScheduleInfeasibleTask() { << task.GetTaskSpecification().TaskId(); bool is_infeasible; cluster_resource_scheduler_->GetBestSchedulableNode( - task.GetTaskSpecification(), work->PrioritizeLocalNode(), + task.GetTaskSpecification(), + work->PrioritizeLocalNode(), /*exclude_local_node*/ false, - /*requires_object_store_memory*/ false, &is_infeasible); + /*requires_object_store_memory*/ false, + &is_infeasible); // There is no node that has available resources to run the request. // Move on to the next shape. @@ -203,8 +211,8 @@ bool ClusterTaskManager::CancelTask( } } - return local_task_manager_->CancelTask(task_id, failure_type, - scheduling_failure_message); + return local_task_manager_->CancelTask( + task_id, failure_type, scheduling_failure_message); } void ClusterTaskManager::FillPendingActorInfo(rpc::GetNodeStatsReply *reply) const { @@ -218,7 +226,9 @@ void ClusterTaskManager::FillResourceUsage( } bool ClusterTaskManager::AnyPendingTasksForResourceAcquisition( - RayTask *exemplar, bool *any_pending, int *num_pending_actor_creation, + RayTask *exemplar, + bool *any_pending, + int *num_pending_actor_creation, int *num_pending_tasks) const { // We are guaranteed that these tasks are blocked waiting for resources after a // call to ScheduleAndDispatchTasks(). They may be waiting for workers as well, but diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 6fb8ae271..731902638 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -72,7 +72,8 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { /// \param is_selected_based_on_locality : should schedule on local node if possible. /// \param reply: The reply of the lease request. /// \param send_reply_callback: The function used during dispatching. - void QueueAndScheduleTask(const RayTask &task, bool grant_or_reject, + void QueueAndScheduleTask(const RayTask &task, + bool grant_or_reject, bool is_selected_based_on_locality, rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) override; @@ -115,7 +116,8 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { /// \param[in,out] num_pending_actor_creation: Number of pending actor creation tasks. /// \param[in,out] num_pending_tasks: Number of pending tasks. /// \return True if any progress is any tasks are pending. 
- bool AnyPendingTasksForResourceAcquisition(RayTask *example, bool *any_pending, + bool AnyPendingTasksForResourceAcquisition(RayTask *example, + bool *any_pending, int *num_pending_actor_creation, int *num_pending_tasks) const override; diff --git a/src/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/ray/raylet/scheduling/cluster_task_manager_interface.h index 37d139e76..23ea14b41 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_interface.h +++ b/src/ray/raylet/scheduling/cluster_task_manager_interface.h @@ -62,7 +62,8 @@ class ClusterTaskManagerInterface { /// but no spillback. /// \param reply: The reply of the lease request. /// \param send_reply_callback: The function used during dispatching. - virtual void QueueAndScheduleTask(const RayTask &task, bool grant_or_reject, + virtual void QueueAndScheduleTask(const RayTask &task, + bool grant_or_reject, bool is_selected_based_on_locality, rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) = 0; @@ -74,7 +75,8 @@ class ClusterTaskManagerInterface { /// \param[in] num_pending_tasks Number of pending tasks. /// \param[in] any_pending True if there's any pending exemplar. /// \return True if any progress is any tasks are pending. - virtual bool AnyPendingTasksForResourceAcquisition(RayTask *exemplar, bool *any_pending, + virtual bool AnyPendingTasksForResourceAcquisition(RayTask *exemplar, + bool *any_pending, int *num_pending_actor_creation, int *num_pending_tasks) const = 0; diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 74a5cf65c..bbd695f52 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -48,7 +48,8 @@ class MockWorkerPool : public WorkerPoolInterface { public: MockWorkerPool() : num_pops(0) {} - void PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback, + void PopWorker(const TaskSpecification &task_spec, + const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json) { num_pops++; const int runtime_env_hash = task_spec.GetRuntimeEnvHash(); @@ -72,7 +73,8 @@ class MockWorkerPool : public WorkerPoolInterface { for (const auto &callback : pair.second) { // No task should be dispatched. 
ASSERT_FALSE( - callback(nullptr, status, + callback(nullptr, + status, /*runtime_env_setup_error_msg*/ runtime_env_setup_error_msg)); } } @@ -136,17 +138,29 @@ std::shared_ptr CreateSingleNodeScheduler( } RayTask CreateTask( - const std::unordered_map &required_resources, int num_args = 0, + const std::unordered_map &required_resources, + int num_args = 0, std::vector args = {}, const std::shared_ptr runtime_env_info = nullptr) { TaskSpecBuilder spec_builder; TaskID id = RandomTaskId(); JobID job_id = RandomJobId(); rpc::Address address; - spec_builder.SetCommonTaskSpec(id, "dummy_task", Language::PYTHON, + spec_builder.SetCommonTaskSpec(id, + "dummy_task", + Language::PYTHON, FunctionDescriptorBuilder::BuildPython("", "", "", ""), - job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0, - required_resources, {}, "", 0, runtime_env_info); + job_id, + TaskID::Nil(), + 0, + TaskID::Nil(), + address, + 0, + required_resources, + {}, + "", + 0, + runtime_env_info); if (!args.empty()) { for (auto &arg : args) { @@ -214,14 +228,16 @@ class ClusterTaskManagerTest : public ::testing::Test { ClusterTaskManagerTest(double num_cpus_at_head = 8.0, double num_gpus_at_head = 0.0) : gcs_client_(std::make_unique()), id_(NodeID::FromRandom()), - scheduler_(CreateSingleNodeScheduler(id_.Binary(), num_cpus_at_head, - num_gpus_at_head, *gcs_client_)), + scheduler_(CreateSingleNodeScheduler( + id_.Binary(), num_cpus_at_head, num_gpus_at_head, *gcs_client_)), is_owner_alive_(true), node_info_calls_(0), announce_infeasible_task_calls_(0), dependency_manager_(missing_objects_), local_task_manager_(std::make_shared( - id_, scheduler_, dependency_manager_, /* is_owner_alive= */ + id_, + scheduler_, + dependency_manager_, /* is_owner_alive= */ [this](const WorkerID &worker_id, const NodeID &node_id) { return is_owner_alive_; }, @@ -233,7 +249,8 @@ class ClusterTaskManagerTest : public ::testing::Test { } return nullptr; }, - pool_, leased_workers_, + pool_, + leased_workers_, /* get_task_arguments= */ [this](const std::vector &object_ids, std::vector> *results) { @@ -249,7 +266,8 @@ class ClusterTaskManagerTest : public ::testing::Test { /*max_pinned_task_arguments_bytes=*/1000, /*get_time=*/[this]() { return current_time_ms_; })), task_manager_( - id_, scheduler_, + id_, + scheduler_, /* get_node_info= */ [this](const NodeID &node_id) -> const rpc::GcsNodeInfo * { node_info_calls_++; @@ -278,7 +296,9 @@ class ClusterTaskManagerTest : public ::testing::Test { void Shutdown() {} - void AddNode(const NodeID &id, double num_cpus, double num_gpus = 0, + void AddNode(const NodeID &id, + double num_cpus, + double num_gpus = 0, double memory = 0) { absl::flat_hash_map node_resources; node_resources[ray::kCPU_ResourceLabel] = num_cpus; @@ -379,8 +399,8 @@ TEST_F(ClusterTaskManagerTest, BasicTest) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -418,8 +438,8 @@ TEST_F(ClusterTaskManagerTest, IdempotencyTest) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -482,8 +502,8 @@ 
TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) { rpc::RequestWorkerLeaseReply reply_A; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -537,8 +557,8 @@ TEST_F(ClusterTaskManagerTest, BlockedWorkerDiesTest) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -582,8 +602,8 @@ TEST_F(ClusterTaskManagerTest, BlockedWorkerDies2Test) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -627,8 +647,8 @@ TEST_F(ClusterTaskManagerTest, NoFeasibleNodeTest) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -658,8 +678,8 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { rpc::RequestWorkerLeaseReply reply; int num_callbacks = 0; int *num_callbacks_ptr = &num_callbacks; - auto callback = [num_callbacks_ptr](Status, std::function, - std::function) { + auto callback = [num_callbacks_ptr]( + Status, std::function, std::function) { (*num_callbacks_ptr) = *num_callbacks_ptr + 1; }; @@ -764,8 +784,8 @@ TEST_F(ClusterTaskManagerTest, TestIsSelectedBasedOnLocality) { ASSERT_EQ(pool_.workers.size(), 1); auto task3 = CreateTask({{ray::kCPU_ResourceLabel, 1}}); - task_manager_.QueueAndScheduleTask(task3, false, /*is_selected_based_on_locality=*/true, - &local_reply, callback); + task_manager_.QueueAndScheduleTask( + task3, false, /*is_selected_based_on_locality=*/true, &local_reply, callback); pool_.TriggerCallbacks(); ASSERT_EQ(num_callbacks, 3); // The third task was dispatched. @@ -798,8 +818,8 @@ TEST_F(ClusterTaskManagerTest, TestGrantOrReject) { auto task1 = CreateTask({{ray::kCPU_ResourceLabel, 5}}); rpc::RequestWorkerLeaseReply local_reply; - task_manager_.QueueAndScheduleTask(task1, /*grant_or_reject=*/false, false, - &local_reply, callback); + task_manager_.QueueAndScheduleTask( + task1, /*grant_or_reject=*/false, false, &local_reply, callback); pool_.TriggerCallbacks(); ASSERT_EQ(num_callbacks, 1); // The first task was dispatched. @@ -808,8 +828,8 @@ TEST_F(ClusterTaskManagerTest, TestGrantOrReject) { auto task2 = CreateTask({{ray::kCPU_ResourceLabel, 1}}); rpc::RequestWorkerLeaseReply spillback_reply; - task_manager_.QueueAndScheduleTask(task2, /*grant_or_reject=*/false, false, - &spillback_reply, callback); + task_manager_.QueueAndScheduleTask( + task2, /*grant_or_reject=*/false, false, &spillback_reply, callback); pool_.TriggerCallbacks(); // The second task was spilled. 
ASSERT_EQ(num_callbacks, 2); @@ -819,8 +839,8 @@ TEST_F(ClusterTaskManagerTest, TestGrantOrReject) { ASSERT_EQ(pool_.workers.size(), 1); auto task3 = CreateTask({{ray::kCPU_ResourceLabel, 1}}); - task_manager_.QueueAndScheduleTask(task3, /*grant_or_reject=*/true, false, &local_reply, - callback); + task_manager_.QueueAndScheduleTask( + task3, /*grant_or_reject=*/true, false, &local_reply, callback); pool_.TriggerCallbacks(); ASSERT_EQ(num_callbacks, 3); // The third task was dispatched. @@ -864,8 +884,8 @@ TEST_F(ClusterTaskManagerTest, TestSpillAfterAssigned) { // Resources are no longer available for the second. auto task2 = CreateTask({{ray::kCPU_ResourceLabel, 5}}); rpc::RequestWorkerLeaseReply reject_reply; - task_manager_.QueueAndScheduleTask(task2, /*grant_or_reject=*/true, false, - &reject_reply, callback); + task_manager_.QueueAndScheduleTask( + task2, /*grant_or_reject=*/true, false, &reject_reply, callback); pool_.TriggerCallbacks(); // The second task was rejected. @@ -910,8 +930,8 @@ TEST_F(ClusterTaskManagerTest, NotOKPopWorkerTest) { rpc::RequestWorkerLeaseReply reply; bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; task_manager_.QueueAndScheduleTask(task1, false, false, &reply, callback); @@ -954,8 +974,8 @@ TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -1010,8 +1030,8 @@ TEST_F(ClusterTaskManagerTest, TaskCancelInfeasibleTask) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -1052,8 +1072,8 @@ TEST_F(ClusterTaskManagerTest, HeartbeatTest) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -1069,8 +1089,8 @@ TEST_F(ClusterTaskManagerTest, HeartbeatTest) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -1087,8 +1107,8 @@ TEST_F(ClusterTaskManagerTest, HeartbeatTest) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -1105,8 +1125,8 @@ TEST_F(ClusterTaskManagerTest, HeartbeatTest) { bool callback_called = false; bool *callback_called_ptr = &callback_called; - auto callback = [callback_called_ptr](Status, std::function, - std::function) { + auto callback = [callback_called_ptr]( + Status, std::function, std::function) { *callback_called_ptr = true; }; @@ -1169,8 +1189,8 @@ 
TEST_F(ClusterTaskManagerTest, BacklogReportTest) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -1258,8 +1278,8 @@ TEST_F(ClusterTaskManagerTest, OwnerDeadTest) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -1294,8 +1314,8 @@ TEST_F(ClusterTaskManagerTest, TestInfeasibleTaskWarning) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 12}}); rpc::RequestWorkerLeaseReply reply; std::shared_ptr callback_occurred = std::make_shared(false); - auto callback = [callback_occurred](Status, std::function, - std::function) { + auto callback = [callback_occurred]( + Status, std::function, std::function) { *callback_occurred = true; }; task_manager_.QueueAndScheduleTask(task, false, false, &reply, callback); @@ -1341,8 +1361,8 @@ TEST_F(ClusterTaskManagerTest, TestMultipleInfeasibleTasksWarnOnce) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 12}}); rpc::RequestWorkerLeaseReply reply; std::shared_ptr callback_occurred = std::make_shared(false); - auto callback = [callback_occurred](Status, std::function, - std::function) { + auto callback = [callback_occurred]( + Status, std::function, std::function) { *callback_occurred = true; }; task_manager_.QueueAndScheduleTask(task, false, false, &reply, callback); @@ -1353,8 +1373,8 @@ TEST_F(ClusterTaskManagerTest, TestMultipleInfeasibleTasksWarnOnce) { RayTask task2 = CreateTask({{ray::kCPU_ResourceLabel, 12}}); rpc::RequestWorkerLeaseReply reply2; std::shared_ptr callback_occurred2 = std::make_shared(false); - auto callback2 = [callback_occurred2](Status, std::function, - std::function) { + auto callback2 = [callback_occurred2]( + Status, std::function, std::function) { *callback_occurred2 = true; }; task_manager_.QueueAndScheduleTask(task2, false, false, &reply2, callback2); @@ -1374,8 +1394,8 @@ TEST_F(ClusterTaskManagerTest, TestAnyPendingTasksForResourceAcquisition) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 6}}); rpc::RequestWorkerLeaseReply reply; std::shared_ptr callback_occurred = std::make_shared(false); - auto callback = [callback_occurred](Status, std::function, - std::function) { + auto callback = [callback_occurred]( + Status, std::function, std::function) { *callback_occurred = true; }; task_manager_.QueueAndScheduleTask(task, false, false, &reply, callback); @@ -1396,8 +1416,8 @@ TEST_F(ClusterTaskManagerTest, TestAnyPendingTasksForResourceAcquisition) { RayTask task2 = CreateTask({{ray::kCPU_ResourceLabel, 6}}); rpc::RequestWorkerLeaseReply reply2; std::shared_ptr callback_occurred2 = std::make_shared(false); - auto callback2 = [callback_occurred2](Status, std::function, - std::function) { + auto callback2 = [callback_occurred2]( + Status, std::function, std::function) { *callback_occurred2 = true; }; task_manager_.QueueAndScheduleTask(task2, false, false, &reply2, callback2); @@ -1419,8 +1439,8 @@ TEST_F(ClusterTaskManagerTest, ArgumentEvicted) { rpc::RequestWorkerLeaseReply reply; int num_callbacks = 0; int *num_callbacks_ptr = &num_callbacks; - auto callback = 
[num_callbacks_ptr](Status, std::function, - std::function) { + auto callback = [num_callbacks_ptr]( + Status, std::function, std::function) { (*num_callbacks_ptr) = *num_callbacks_ptr + 1; }; @@ -1470,7 +1490,10 @@ TEST_F(ClusterTaskManagerTest, FeasibleToNonFeasible) { rpc::RequestWorkerLeaseReply reply1; bool callback_occurred1 = false; task_manager_.QueueAndScheduleTask( - task1, false, false, &reply1, + task1, + false, + false, + &reply1, [&callback_occurred1](Status, std::function, std::function) { callback_occurred1 = true; }); @@ -1491,7 +1514,10 @@ TEST_F(ClusterTaskManagerTest, FeasibleToNonFeasible) { rpc::RequestWorkerLeaseReply reply2; bool callback_occurred2 = false; task_manager_.QueueAndScheduleTask( - task2, false, false, &reply2, + task2, + false, + false, + &reply2, [&callback_occurred2](Status, std::function, std::function) { callback_occurred2 = true; }); @@ -1658,8 +1684,8 @@ TEST_F(ClusterTaskManagerTest, PinnedArgsMemoryTest) { rpc::RequestWorkerLeaseReply reply; int num_callbacks = 0; int *num_callbacks_ptr = &num_callbacks; - auto callback = [num_callbacks_ptr](Status, std::function, - std::function) { + auto callback = [num_callbacks_ptr]( + Status, std::function, std::function) { (*num_callbacks_ptr) = *num_callbacks_ptr + 1; }; @@ -1712,8 +1738,8 @@ TEST_F(ClusterTaskManagerTest, PinnedArgsSameMemoryTest) { rpc::RequestWorkerLeaseReply reply; int num_callbacks = 0; int *num_callbacks_ptr = &num_callbacks; - auto callback = [num_callbacks_ptr](Status, std::function, - std::function) { + auto callback = [num_callbacks_ptr]( + Status, std::function, std::function) { (*num_callbacks_ptr) = *num_callbacks_ptr + 1; }; @@ -1728,8 +1754,8 @@ TEST_F(ClusterTaskManagerTest, PinnedArgsSameMemoryTest) { AssertPinnedTaskArgumentsPresent(task); // This task can run because it depends on the same object as the first task. 
- auto task2 = CreateTask({{ray::kCPU_ResourceLabel, 1}}, 1, - task.GetTaskSpecification().GetDependencyIds()); + auto task2 = CreateTask( + {{ray::kCPU_ResourceLabel, 1}}, 1, task.GetTaskSpecification().GetDependencyIds()); task_manager_.QueueAndScheduleTask(task2, false, false, &reply, callback); pool_.TriggerCallbacks(); ASSERT_EQ(num_callbacks, 2); @@ -1751,8 +1777,8 @@ TEST_F(ClusterTaskManagerTest, LargeArgsNoStarvationTest) { rpc::RequestWorkerLeaseReply reply; int num_callbacks = 0; int *num_callbacks_ptr = &num_callbacks; - auto callback = [num_callbacks_ptr](Status, std::function, - std::function) { + auto callback = [num_callbacks_ptr]( + Status, std::function, std::function) { (*num_callbacks_ptr) = *num_callbacks_ptr + 1; }; @@ -1796,14 +1822,14 @@ TEST_F(ClusterTaskManagerTest, PopWorkerExactlyOnce) { runtime_env_info.reset(new rpc::RuntimeEnvInfo()); runtime_env_info->set_serialized_runtime_env(serialized_runtime_env); - RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 4}}, /*num_args=*/0, /*args=*/{}, - runtime_env_info); + RayTask task = CreateTask( + {{ray::kCPU_ResourceLabel, 4}}, /*num_args=*/0, /*args=*/{}, runtime_env_info); auto runtime_env_hash = task.GetTaskSpecification().GetRuntimeEnvHash(); rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](Status, std::function, - std::function) { + auto callback = [callback_occurred_ptr]( + Status, std::function, std::function) { *callback_occurred_ptr = true; }; @@ -1844,11 +1870,14 @@ TEST_F(ClusterTaskManagerTest, CapRunningOnDispatchQueue) { scheduler_->GetLocalResourceManager().AddLocalResourceInstances( scheduling::ResourceID(ray::kGPU_ResourceLabel), {1, 1, 1}); RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 4}, {ray::kGPU_ResourceLabel, 1}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); RayTask task2 = CreateTask({{ray::kCPU_ResourceLabel, 4}, {ray::kGPU_ResourceLabel, 1}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); RayTask task3 = CreateTask({{ray::kCPU_ResourceLabel, 4}, {ray::kGPU_ResourceLabel, 1}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); auto runtime_env_hash = task.GetTaskSpecification().GetRuntimeEnvHash(); std::vector> workers; for (int i = 0; i < 3; i++) { @@ -1980,7 +2009,8 @@ TEST_F(ClusterTaskManagerTest, SchedulingClassCapIncrease) { std::vector tasks; for (int i = 0; i < 3; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); tasks.emplace_back(task); } @@ -2045,7 +2075,8 @@ TEST_F(ClusterTaskManagerTest, SchedulingClassCapIncrease) { // Now schedule another task of the same scheduling class. 
RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); task_manager_.QueueAndScheduleTask(task, false, false, &reply, callback); std::shared_ptr new_worker = @@ -2077,7 +2108,8 @@ TEST_F(ClusterTaskManagerTest, SchedulingClassCapResetTest) { std::vector tasks; for (int i = 0; i < 2; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); tasks.emplace_back(task); } @@ -2117,7 +2149,8 @@ TEST_F(ClusterTaskManagerTest, SchedulingClassCapResetTest) { for (int i = 0; i < 2; i++) { RayTask task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); task_manager_.QueueAndScheduleTask(task, false, false, &reply, callback); } @@ -2142,7 +2175,8 @@ TEST_F(ClusterTaskManagerTest, SchedulingClassCapResetTest) { { // Ensure a class of a differenct scheduling class can still be scheduled. RayTask task5 = CreateTask({}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); task_manager_.QueueAndScheduleTask(task5, false, false, &reply, callback); std::shared_ptr worker5 = std::make_shared(WorkerID::FromRandom(), 1234, runtime_env_hash); @@ -2164,7 +2198,8 @@ TEST_F(ClusterTaskManagerTest, SchedulingClassCapResetTest) { TEST_F(ClusterTaskManagerTest, DispatchTimerAfterRequestTest) { int64_t UNIT = RayConfig::instance().worker_cap_initial_backoff_delay_ms(); RayTask first_task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); rpc::RequestWorkerLeaseReply reply; int num_callbacks = 0; @@ -2187,7 +2222,8 @@ TEST_F(ClusterTaskManagerTest, DispatchTimerAfterRequestTest) { ASSERT_EQ(num_callbacks, 1); RayTask second_task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); task_manager_.QueueAndScheduleTask(second_task, false, false, &reply, callback); pool_.TriggerCallbacks(); @@ -2214,7 +2250,8 @@ TEST_F(ClusterTaskManagerTest, DispatchTimerAfterRequestTest) { current_time_ms_ += 100000 * UNIT; RayTask third_task = CreateTask({{ray::kCPU_ResourceLabel, 8}}, - /*num_args=*/0, /*args=*/{}); + /*num_args=*/0, + /*args=*/{}); task_manager_.QueueAndScheduleTask(third_task, false, false, &reply, callback); pool_.TriggerCallbacks(); @@ -2243,7 +2280,8 @@ TEST_F(ClusterTaskManagerTestWithoutCPUsAtHead, OneCpuInfeasibleTask) { rpc::RequestWorkerLeaseReply reply; bool callback_occurred = false; bool *callback_occurred_ptr = &callback_occurred; - auto callback = [callback_occurred_ptr](const Status &, const std::function &, + auto callback = [callback_occurred_ptr](const Status &, + const std::function &, const std::function &) { *callback_occurred_ptr = true; }; diff --git a/src/ray/raylet/scheduling/internal.h b/src/ray/raylet/scheduling/internal.h index d0d942d6a..eb7fe6180 100644 --- a/src/ray/raylet/scheduling/internal.h +++ b/src/ray/raylet/scheduling/internal.h @@ -62,8 +62,11 @@ class Work { rpc::RequestWorkerLeaseReply *reply; std::function callback; std::shared_ptr allocated_instances; - Work(RayTask task, bool grant_or_reject, bool is_selected_based_on_locality, - rpc::RequestWorkerLeaseReply *reply, std::function callback, + Work(RayTask task, + bool grant_or_reject, + bool is_selected_based_on_locality, + rpc::RequestWorkerLeaseReply *reply, + std::function callback, WorkStatus status = WorkStatus::WAITING) : task(task), grant_or_reject(grant_or_reject), diff --git 
a/src/ray/raylet/scheduling/local_resource_manager.cc b/src/ray/raylet/scheduling/local_resource_manager.cc index 29a3a791b..a82b4da22 100644 --- a/src/ray/raylet/scheduling/local_resource_manager.cc +++ b/src/ray/raylet/scheduling/local_resource_manager.cc @@ -22,7 +22,8 @@ namespace ray { LocalResourceManager::LocalResourceManager( - scheduling::NodeID local_node_id, const NodeResources &node_resources, + scheduling::NodeID local_node_id, + const NodeResources &node_resources, std::function get_used_object_store_memory, std::function get_pull_manager_at_capacity, std::function resource_change_subscriber) @@ -160,7 +161,8 @@ void LocalResourceManager::InitLocalResources(const NodeResources &node_resource bool is_unit_instance = predefined_unit_instance_resources_.find(i) != predefined_unit_instance_resources_.end(); InitResourceInstances(node_resources.predefined_resources[i].total, - is_unit_instance, &local_resources_.predefined_resources[i]); + is_unit_instance, + &local_resources_.predefined_resources[i]); } } @@ -169,7 +171,8 @@ void LocalResourceManager::InitLocalResources(const NodeResources &node_resource } for (auto it = node_resources.custom_resources.begin(); - it != node_resources.custom_resources.end(); ++it) { + it != node_resources.custom_resources.end(); + ++it) { if (it->second.total > 0) { bool is_unit_instance = custom_unit_instance_resources_.find(it->first) != custom_unit_instance_resources_.end(); @@ -196,7 +199,8 @@ std::vector LocalResourceManager::AddAvailableResourceInstances( } std::vector LocalResourceManager::SubtractAvailableResourceInstances( - std::vector available, ResourceInstanceCapacities *resource_instances, + std::vector available, + ResourceInstanceCapacities *resource_instances, bool allow_going_negative) const { RAY_CHECK(available.size() == resource_instances->available.size()); @@ -221,7 +225,8 @@ std::vector LocalResourceManager::SubtractAvailableResourceInstances } bool LocalResourceManager::AllocateResourceInstances( - FixedPoint demand, std::vector &available, + FixedPoint demand, + std::vector &available, std::vector *allocation) const { allocation->resize(available.size()); FixedPoint remaining_demand = demand; @@ -314,8 +319,8 @@ bool LocalResourceManager::AllocateTaskResourceInstances( if (it != local_resources_.custom_resources.end()) { if (task_req_custom_resource.second > 0) { std::vector allocation; - bool success = AllocateResourceInstances(task_req_custom_resource.second, - it->second.available, &allocation); + bool success = AllocateResourceInstances( + task_req_custom_resource.second, it->second.available, &allocation); // Even if allocation failed we need to remember partial allocations to correctly // free resources. task_allocation->custom_resources.emplace(it->first, allocation); @@ -371,7 +376,8 @@ std::vector LocalResourceManager::AddResourceInstances( } std::vector LocalResourceManager::SubtractResourceInstances( - scheduling::ResourceID resource_id, const std::vector &resource_instances, + scheduling::ResourceID resource_id, + const std::vector &resource_instances, bool allow_going_negative) { std::vector resource_instances_fp = VectorDoubleToVectorFixedPoint(resource_instances); @@ -380,9 +386,10 @@ std::vector LocalResourceManager::SubtractResourceInstances( return resource_instances; // No underflow. 
} - auto underflow = SubtractAvailableResourceInstances( - resource_instances_fp, &local_resources_.GetMutable(resource_id), - allow_going_negative); + auto underflow = + SubtractAvailableResourceInstances(resource_instances_fp, + &local_resources_.GetMutable(resource_id), + allow_going_negative); OnResourceChanged(); return VectorFixedPointToVectorDouble(underflow); @@ -437,8 +444,8 @@ NodeResources ToNodeResources(const NodeResourceInstances &instance) { int64_t resource_name = custom_resource.first; auto &instances = custom_resource.second; - FixedPoint available = std::accumulate(instances.available.begin(), - instances.available.end(), FixedPoint()); + FixedPoint available = std::accumulate( + instances.available.begin(), instances.available.end(), FixedPoint()); FixedPoint total = std::accumulate(instances.total.begin(), instances.total.end(), FixedPoint()); diff --git a/src/ray/raylet/scheduling/local_resource_manager.h b/src/ray/raylet/scheduling/local_resource_manager.h index c50c21a86..be1715841 100644 --- a/src/ray/raylet/scheduling/local_resource_manager.h +++ b/src/ray/raylet/scheduling/local_resource_manager.h @@ -40,7 +40,8 @@ namespace ray { class LocalResourceManager { public: LocalResourceManager( - scheduling::NodeID local_node_id, const NodeResources &node_resources, + scheduling::NodeID local_node_id, + const NodeResources &node_resources, std::function get_used_object_store_memory, std::function get_pull_manager_at_capacity, std::function resource_change_subscriber); @@ -167,7 +168,8 @@ class LocalResourceManager { /// \param unit_instances: If true, we split the resource in unit-size instances. /// If false, we create a single instance of capacity "total". /// \param instance_list: The list of capacities this resource instances. - void InitResourceInstances(FixedPoint total, bool unit_instances, + void InitResourceInstances(FixedPoint total, + bool unit_instances, ResourceInstanceCapacities *instance_list); /// Init the information about which resources are unit_instance. @@ -197,7 +199,8 @@ class LocalResourceManager { /// capacities in "available", i.e.,. /// max(available - reasource_instances.available, 0) std::vector SubtractAvailableResourceInstances( - std::vector available, ResourceInstanceCapacities *resource_instances, + std::vector available, + ResourceInstanceCapacities *resource_instances, bool allow_going_negative = false) const; /// Allocate enough capacity across the instances of a resource to satisfy "demand". @@ -231,7 +234,8 @@ class LocalResourceManager { /// \return true, if allocation successful. In this case, the sum of the elements in /// "allocation" is equal to "demand". - bool AllocateResourceInstances(FixedPoint demand, std::vector &available, + bool AllocateResourceInstances(FixedPoint demand, + std::vector &available, std::vector *allocation) const; /// Allocate local resources to satisfy a given request (resource_request). 
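Much of this patch reflows call sites that pass several positional booleans, such as ClusterResourceScheduler::GetBestSchedulableNode. With one argument per line, each flag can carry its own /*name=*/ comment and a rebase conflict touches only the line that actually changed. A sketch of such a call, assuming the resource_scheduler, resource_request, and scheduling_strategy variables from the test hunks above and the parameter order from cluster_resource_scheduler.h:

  int64_t violations = 0;
  bool is_infeasible = false;
  auto best_node = resource_scheduler.GetBestSchedulableNode(
      resource_request,
      scheduling_strategy,
      /*requires_object_store_memory=*/false,
      /*actor_creation=*/false,
      /*force_spillback=*/true,
      &violations,
      &is_infeasible);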
diff --git a/src/ray/raylet/scheduling/local_task_manager.cc b/src/ray/raylet/scheduling/local_task_manager.cc index 2a3af4a0c..622095793 100644 --- a/src/ray/raylet/scheduling/local_task_manager.cc +++ b/src/ray/raylet/scheduling/local_task_manager.cc @@ -29,12 +29,14 @@ LocalTaskManager::LocalTaskManager( std::shared_ptr cluster_resource_scheduler, TaskDependencyManagerInterface &task_dependency_manager, std::function is_owner_alive, - internal::NodeInfoGetter get_node_info, WorkerPoolInterface &worker_pool, + internal::NodeInfoGetter get_node_info, + WorkerPoolInterface &worker_pool, absl::flat_hash_map> &leased_workers, std::function &object_ids, std::vector> *results)> get_task_arguments, - size_t max_pinned_task_arguments_bytes, std::function get_time_ms, + size_t max_pinned_task_arguments_bytes, + std::function get_time_ms, int64_t sched_cls_cap_interval_ms) : self_node_id_(self_node_id), cluster_resource_scheduler_(cluster_resource_scheduler), @@ -255,10 +257,16 @@ void LocalTaskManager::DispatchScheduledTasksToWorkers() { worker_pool_.PopWorker( spec, [this, task_id, scheduling_class, work, is_detached_actor, owner_address]( - const std::shared_ptr worker, PopWorkerStatus status, + const std::shared_ptr worker, + PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { - return PoppedWorkerHandler(worker, status, task_id, scheduling_class, work, - is_detached_actor, owner_address, + return PoppedWorkerHandler(worker, + status, + task_id, + scheduling_class, + work, + is_detached_actor, + owner_address, runtime_env_setup_error_message); }, allocated_instances_serialized_json); @@ -319,7 +327,8 @@ void LocalTaskManager::SpillWaitingTasks() { (*it)->task.GetTaskSpecification(), /*prioritize_local_node*/ true, /*exclude_local_node*/ force_spillback, - /*requires_object_store_memory*/ true, &is_infeasible); + /*requires_object_store_memory*/ true, + &is_infeasible); if (!scheduling_node_id.IsNil() && scheduling_node_id.Binary() != self_node_id_.Binary()) { NodeID node_id = NodeID::FromBinary(scheduling_node_id.Binary()); @@ -349,9 +358,11 @@ void LocalTaskManager::SpillWaitingTasks() { bool LocalTaskManager::TrySpillback(const std::shared_ptr &work, bool &is_infeasible) { auto scheduling_node_id = cluster_resource_scheduler_->GetBestSchedulableNode( - work->task.GetTaskSpecification(), work->PrioritizeLocalNode(), + work->task.GetTaskSpecification(), + work->PrioritizeLocalNode(), /*exclude_local_node*/ false, - /*requires_object_store_memory*/ false, &is_infeasible); + /*requires_object_store_memory*/ false, + &is_infeasible); if (is_infeasible || scheduling_node_id.IsNil() || scheduling_node_id.Binary() == self_node_id_.Binary()) { @@ -364,9 +375,12 @@ bool LocalTaskManager::TrySpillback(const std::shared_ptr &work, } bool LocalTaskManager::PoppedWorkerHandler( - const std::shared_ptr worker, PopWorkerStatus status, - const TaskID &task_id, SchedulingClass scheduling_class, - const std::shared_ptr &work, bool is_detached_actor, + const std::shared_ptr worker, + PopWorkerStatus status, + const TaskID &task_id, + SchedulingClass scheduling_class, + const std::shared_ptr &work, + bool is_detached_actor, const rpc::Address &owner_address, const std::string &runtime_env_setup_error_message) { const auto &reply = work->reply; @@ -754,7 +768,9 @@ bool LocalTaskManager::CancelTask( } bool LocalTaskManager::AnyPendingTasksForResourceAcquisition( - RayTask *exemplar, bool *any_pending, int *num_pending_actor_creation, + RayTask *exemplar, + bool *any_pending, + int 
*num_pending_actor_creation, int *num_pending_tasks) const { // We are guaranteed that these tasks are blocked waiting for resources after a // call to ScheduleAndDispatchTasks(). They may be waiting for workers as well, but @@ -802,7 +818,8 @@ void LocalTaskManager::Dispatch( std::shared_ptr worker, absl::flat_hash_map> &leased_workers, const std::shared_ptr &allocated_instances, - const RayTask &task, rpc::RequestWorkerLeaseReply *reply, + const RayTask &task, + rpc::RequestWorkerLeaseReply *reply, std::function send_reply_callback) { const auto &task_spec = task.GetTaskSpecification(); @@ -888,7 +905,8 @@ void LocalTaskManager::ClearWorkerBacklog(const WorkerID &worker_id) { } void LocalTaskManager::SetWorkerBacklog(SchedulingClass scheduling_class, - const WorkerID &worker_id, int64_t backlog_size) { + const WorkerID &worker_id, + int64_t backlog_size) { if (backlog_size == 0) { backlog_tracker_[scheduling_class].erase(worker_id); if (backlog_tracker_[scheduling_class].empty()) { diff --git a/src/ray/raylet/scheduling/local_task_manager.h b/src/ray/raylet/scheduling/local_task_manager.h index 1b676dc5f..8559523bd 100644 --- a/src/ray/raylet/scheduling/local_task_manager.h +++ b/src/ray/raylet/scheduling/local_task_manager.h @@ -78,7 +78,8 @@ class LocalTaskManager { std::shared_ptr cluster_resource_scheduler, TaskDependencyManagerInterface &task_dependency_manager, std::function is_owner_alive, - internal::NodeInfoGetter get_node_info, WorkerPoolInterface &worker_pool, + internal::NodeInfoGetter get_node_info, + WorkerPoolInterface &worker_pool, absl::flat_hash_map> &leased_workers, std::function &object_ids, std::vector> *results)> @@ -128,7 +129,8 @@ class LocalTaskManager { /// \param[in,out] num_pending_actor_creation: Number of pending actor creation tasks. /// \param[in,out] num_pending_tasks: Number of pending tasks. /// \return True if any progress is any tasks are pending. - bool AnyPendingTasksForResourceAcquisition(RayTask *example, bool *any_pending, + bool AnyPendingTasksForResourceAcquisition(RayTask *example, + bool *any_pending, int *num_pending_actor_creation, int *num_pending_tasks) const; @@ -159,7 +161,8 @@ class LocalTaskManager { /// Calculate normal task resources. ResourceSet CalcNormalTaskResources() const; - void SetWorkerBacklog(SchedulingClass scheduling_class, const WorkerID &worker_id, + void SetWorkerBacklog(SchedulingClass scheduling_class, + const WorkerID &worker_id, int64_t backlog_size); void ClearWorkerBacklog(const WorkerID &worker_id); @@ -171,10 +174,12 @@ class LocalTaskManager { /// Handle the popped worker from worker pool. bool PoppedWorkerHandler(const std::shared_ptr worker, - PopWorkerStatus status, const TaskID &task_id, + PopWorkerStatus status, + const TaskID &task_id, SchedulingClass scheduling_class, const std::shared_ptr &work, - bool is_detached_actor, const rpc::Address &owner_address, + bool is_detached_actor, + const rpc::Address &owner_address, const std::string &runtime_env_setup_error_message); /// Attempts to dispatch all tasks which are ready to run. 
A task @@ -222,7 +227,8 @@ class LocalTaskManager { std::shared_ptr worker, absl::flat_hash_map> &leased_workers_, const std::shared_ptr &allocated_instances, - const RayTask &task, rpc::RequestWorkerLeaseReply *reply, + const RayTask &task, + rpc::RequestWorkerLeaseReply *reply, std::function send_reply_callback); void Spillback(const NodeID &spillback_to, const std::shared_ptr &work); diff --git a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc index 8d72c46d1..369644154 100644 --- a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc +++ b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.cc @@ -24,8 +24,11 @@ namespace ray { namespace raylet_scheduling_policy { scheduling::NodeID HybridSchedulingPolicy::HybridPolicyWithFilter( - const ResourceRequest &resource_request, float spread_threshold, bool force_spillback, - bool require_node_available, NodeFilter node_filter) { + const ResourceRequest &resource_request, + float spread_threshold, + bool force_spillback, + bool require_node_available, + NodeFilter node_filter) { // Step 1: Generate the traversal order. We guarantee that the first node is local, to // encourage local scheduling. The rest of the traversal order should be globally // consistent, to encourage using "warm" workers. @@ -138,23 +141,28 @@ scheduling::NodeID HybridSchedulingPolicy::Schedule( RAY_CHECK(options.scheduling_type == SchedulingType::HYBRID) << "HybridPolicy policy requires type = HYBRID"; if (!options.avoid_gpu_nodes || resource_request.IsGPURequest()) { - return HybridPolicyWithFilter(resource_request, options.spread_threshold, + return HybridPolicyWithFilter(resource_request, + options.spread_threshold, options.avoid_local_node, options.require_node_available); } // Try schedule on non-GPU nodes. - auto best_node_id = HybridPolicyWithFilter( - resource_request, options.spread_threshold, options.avoid_local_node, - /*require_node_available*/ true, NodeFilter::kNonGpu); + auto best_node_id = HybridPolicyWithFilter(resource_request, + options.spread_threshold, + options.avoid_local_node, + /*require_node_available*/ true, + NodeFilter::kNonGpu); if (!best_node_id.IsNil()) { return best_node_id; } // If we cannot find any available node from non-gpu nodes, fallback to the original // scheduling - return HybridPolicyWithFilter(resource_request, options.spread_threshold, - options.avoid_local_node, options.require_node_available); + return HybridPolicyWithFilter(resource_request, + options.spread_threshold, + options.avoid_local_node, + options.require_node_available); } } // namespace raylet_scheduling_policy diff --git a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h index 665573609..4dd3a0558 100644 --- a/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h +++ b/src/ray/raylet/scheduling/policy/hybrid_scheduling_policy.h @@ -86,7 +86,8 @@ class HybridSchedulingPolicy : public ISchedulingPolicy { /// \return -1 if the task is unfeasible, otherwise the node id (key in `nodes`) to /// schedule on. 
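Every hunk in this patch applies the same mechanical rule, so it is worth seeing once in isolation. Below is a minimal before/after sketch with a made-up C++ signature (illustrative only; it does not appear in this patch, and the exact .clang-format options are not shown in these hunks):

#include <string>

// Before: clang-format bin-packs parameters, so several share each line and
// inserting one in the middle rewraps (and conflicts on) its neighbors.
void SubmitTask(const std::string &task_name, const std::string &owner_id, int max_retries,
                bool retry_exceptions);

// After: once the signature exceeds the column limit, every parameter gets
// its own line, so adding or removing a parameter is a one-line diff.
void SubmitTask(const std::string &task_name,
                const std::string &owner_id,
                int max_retries,
                bool retry_exceptions);

Signatures that already fit on a single line are left alone, which is why many short declarations in the hunks above and below are untouched.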
scheduling::NodeID HybridPolicyWithFilter(const ResourceRequest &resource_request, - float spread_threshold, bool force_spillback, + float spread_threshold, + bool force_spillback, bool require_available, NodeFilter node_filter = NodeFilter::kAny); }; diff --git a/src/ray/raylet/scheduling/policy/scheduling_options.h b/src/ray/raylet/scheduling/policy/scheduling_options.h index 9950470e6..abbdaf576 100644 --- a/src/ray/raylet/scheduling/policy/scheduling_options.h +++ b/src/ray/raylet/scheduling/policy/scheduling_options.h @@ -34,7 +34,8 @@ enum class SchedulingType { struct SchedulingOptions { static SchedulingOptions Random() { return SchedulingOptions(SchedulingType::RANDOM, - /*spread_threshold*/ 0, /*avoid_local_node*/ false, + /*spread_threshold*/ 0, + /*avoid_local_node*/ false, /*require_node_available*/ true, /*avoid_gpu_nodes*/ false); } @@ -42,7 +43,8 @@ struct SchedulingOptions { // construct option for spread scheduling policy. static SchedulingOptions Spread(bool avoid_local_node, bool require_node_available) { return SchedulingOptions(SchedulingType::SPREAD, - /*spread_threshold*/ 0, avoid_local_node, + /*spread_threshold*/ 0, + avoid_local_node, require_node_available, RayConfig::instance().scheduler_avoid_gpu_nodes()); } @@ -51,7 +53,8 @@ struct SchedulingOptions { static SchedulingOptions Hybrid(bool avoid_local_node, bool require_node_available) { return SchedulingOptions(SchedulingType::HYBRID, RayConfig::instance().scheduler_spread_threshold(), - avoid_local_node, require_node_available, + avoid_local_node, + require_node_available, RayConfig::instance().scheduler_avoid_gpu_nodes()); } @@ -62,8 +65,11 @@ struct SchedulingOptions { bool avoid_gpu_nodes; private: - SchedulingOptions(SchedulingType type, float spread_threshold, bool avoid_local_node, - bool require_node_available, bool avoid_gpu_nodes) + SchedulingOptions(SchedulingType type, + float spread_threshold, + bool avoid_local_node, + bool require_node_available, + bool avoid_gpu_nodes) : scheduling_type(type), spread_threshold(spread_threshold), avoid_local_node(avoid_local_node), diff --git a/src/ray/raylet/scheduling/policy/scheduling_policy_test.cc b/src/ray/raylet/scheduling/policy/scheduling_policy_test.cc index 1021c1720..954bf5793 100644 --- a/src/ray/raylet/scheduling/policy/scheduling_policy_test.cc +++ b/src/ray/raylet/scheduling/policy/scheduling_policy_test.cc @@ -23,9 +23,12 @@ namespace raylet { using ::testing::_; using namespace ray::raylet_scheduling_policy; -NodeResources CreateNodeResources(double available_cpu, double total_cpu, - double available_memory, double total_memory, - double available_gpu, double total_gpu) { +NodeResources CreateNodeResources(double available_cpu, + double total_cpu, + double available_memory, + double total_memory, + double available_gpu, + double total_gpu) { NodeResources resources; resources.predefined_resources = {{available_cpu, total_cpu}, {available_memory, total_memory}, @@ -42,11 +45,15 @@ class SchedulingPolicyTest : public ::testing::Test { absl::flat_hash_map nodes; SchedulingOptions HybridOptions( - float spread, bool avoid_local_node, bool require_node_available, + float spread, + bool avoid_local_node, + bool require_node_available, bool avoid_gpu_nodes = RayConfig::instance().scheduler_avoid_gpu_nodes()) { return SchedulingOptions(SchedulingType::HYBRID, RayConfig::instance().scheduler_spread_threshold(), - avoid_local_node, require_node_available, avoid_gpu_nodes); + avoid_local_node, + require_node_available, + avoid_gpu_nodes); } }; @@ -300,8 
+307,8 @@ TEST_F(SchedulingPolicyTest, AvoidSchedulingCPURequestsOnGPUNodes) { // we should schedule on remote node. const ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); const auto to_schedule = - raylet_scheduling_policy::CompositeSchedulingPolicy(local_node, nodes, - [](auto) { return true; }) + raylet_scheduling_policy::CompositeSchedulingPolicy( + local_node, nodes, [](auto) { return true; }) .Schedule(ResourceMapToResourceRequest({{"CPU", 1}}, false), HybridOptions(0.51, false, true, true)); ASSERT_EQ(to_schedule, remote_node); @@ -392,29 +399,41 @@ TEST_F(SchedulingPolicyTest, NonGpuNodePreferredSchedulingTest) { ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); auto to_schedule = raylet_scheduling_policy::CompositeSchedulingPolicy( local_node, nodes, [](auto) { return true; }) - .Schedule(req, HybridOptions(0.51, false, true, - /*gpu_avoid_scheduling*/ true)); + .Schedule(req, + HybridOptions(0.51, + false, + true, + /*gpu_avoid_scheduling*/ true)); ASSERT_EQ(to_schedule, remote_node); req = ResourceMapToResourceRequest({{"CPU", 3}}, false); to_schedule = raylet_scheduling_policy::CompositeSchedulingPolicy( local_node, nodes, [](auto) { return true; }) - .Schedule(req, HybridOptions(0.51, false, true, - /*gpu_avoid_scheduling*/ true)); + .Schedule(req, + HybridOptions(0.51, + false, + true, + /*gpu_avoid_scheduling*/ true)); ASSERT_EQ(to_schedule, remote_node_2); req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); to_schedule = raylet_scheduling_policy::CompositeSchedulingPolicy( local_node, nodes, [](auto) { return true; }) - .Schedule(req, HybridOptions(0.51, false, true, - /*gpu_avoid_scheduling*/ true)); + .Schedule(req, + HybridOptions(0.51, + false, + true, + /*gpu_avoid_scheduling*/ true)); ASSERT_EQ(to_schedule, local_node); req = ResourceMapToResourceRequest({{"CPU", 2}}, false); to_schedule = raylet_scheduling_policy::CompositeSchedulingPolicy( local_node, nodes, [](auto) { return true; }) - .Schedule(req, HybridOptions(0.51, false, true, - /*gpu_avoid_scheduling*/ true)); + .Schedule(req, + HybridOptions(0.51, + false, + true, + /*gpu_avoid_scheduling*/ true)); ASSERT_EQ(to_schedule, remote_node); } diff --git a/src/ray/raylet/scheduling/scheduler_resource_reporter.cc b/src/ray/raylet/scheduling/scheduler_resource_reporter.cc index 715c9f04b..38ad9072a 100644 --- a/src/ray/raylet/scheduling/scheduler_resource_reporter.cc +++ b/src/ray/raylet/scheduling/scheduler_resource_reporter.cc @@ -28,10 +28,12 @@ const int kMaxPendingActorsToReport = 20; }; // namespace SchedulerResourceReporter::SchedulerResourceReporter( - const absl::flat_hash_map< - SchedulingClass, std::deque>> &tasks_to_schedule, - const absl::flat_hash_map< - SchedulingClass, std::deque>> &infeasible_tasks, + const absl::flat_hash_map>> + &tasks_to_schedule, + const absl::flat_hash_map>> + &infeasible_tasks, const LocalTaskManager &local_task_manager) : max_resource_shapes_per_load_report_( RayConfig::instance().max_resource_shapes_per_load_report()), diff --git a/src/ray/raylet/scheduling/scheduler_resource_reporter.h b/src/ray/raylet/scheduling/scheduler_resource_reporter.h index cc20c99dd..edf1fb63c 100644 --- a/src/ray/raylet/scheduling/scheduler_resource_reporter.h +++ b/src/ray/raylet/scheduling/scheduler_resource_reporter.h @@ -30,8 +30,9 @@ class SchedulerResourceReporter { const absl::flat_hash_map>> &tasks_to_schedule, - const absl::flat_hash_map< - SchedulingClass, std::deque>> &infeasible_tasks, + const absl::flat_hash_map>> + 
&infeasible_tasks, const LocalTaskManager &local_task_manager); /// Populate the relevant parts of the heartbeat table. This is intended for diff --git a/src/ray/raylet/scheduling/scheduler_stats.cc b/src/ray/raylet/scheduling/scheduler_stats.cc index dd6f55ab1..c594535a9 100644 --- a/src/ray/raylet/scheduling/scheduler_stats.cc +++ b/src/ray/raylet/scheduling/scheduler_stats.cc @@ -38,22 +38,27 @@ void SchedulerStats::ComputeStats() { size_t num_tasks_waiting_for_workers = 0; size_t num_cancelled_tasks = 0; - size_t num_infeasible_tasks = std::accumulate( - cluster_task_manager_.infeasible_tasks_.begin(), - cluster_task_manager_.infeasible_tasks_.end(), (size_t)0, accumulator); + size_t num_infeasible_tasks = + std::accumulate(cluster_task_manager_.infeasible_tasks_.begin(), + cluster_task_manager_.infeasible_tasks_.end(), + (size_t)0, + accumulator); // TODO(sang): Normally, the # of queued tasks are not large, so this is less likley to // be an issue that we iterate all of them. But if it uses lots of CPU, consider // optimizing by updating live instead of iterating through here. - auto per_work_accumulator = [&num_waiting_for_resource, &num_waiting_for_plasma_memory, + auto per_work_accumulator = [&num_waiting_for_resource, + &num_waiting_for_plasma_memory, &num_waiting_for_remote_node_resources, &num_worker_not_started_by_job_config_not_exist, &num_worker_not_started_by_registration_timeout, &num_worker_not_started_by_process_rate_limit, - &num_tasks_waiting_for_workers, &num_cancelled_tasks]( + &num_tasks_waiting_for_workers, + &num_cancelled_tasks]( size_t state, const std::pair< - int, std::deque>> + int, + std::deque>> &pair) { const auto &work_queue = pair.second; for (auto work_it = work_queue.begin(); work_it != work_queue.end();) { @@ -84,12 +89,16 @@ void SchedulerStats::ComputeStats() { } return state + pair.second.size(); }; - size_t num_tasks_to_schedule = std::accumulate( - cluster_task_manager_.tasks_to_schedule_.begin(), - cluster_task_manager_.tasks_to_schedule_.end(), (size_t)0, per_work_accumulator); - size_t num_tasks_to_dispatch = std::accumulate( - local_task_manager_.tasks_to_dispatch_.begin(), - local_task_manager_.tasks_to_dispatch_.end(), (size_t)0, per_work_accumulator); + size_t num_tasks_to_schedule = + std::accumulate(cluster_task_manager_.tasks_to_schedule_.begin(), + cluster_task_manager_.tasks_to_schedule_.end(), + (size_t)0, + per_work_accumulator); + size_t num_tasks_to_dispatch = + std::accumulate(local_task_manager_.tasks_to_dispatch_.begin(), + local_task_manager_.tasks_to_dispatch_.end(), + (size_t)0, + per_work_accumulator); /// Update the internal states. num_waiting_for_resource_ = num_waiting_for_resource; diff --git a/src/ray/raylet/scheduling/scheduling_policy.cc b/src/ray/raylet/scheduling/scheduling_policy.cc new file mode 100644 index 000000000..75d1329ed --- /dev/null +++ b/src/ray/raylet/scheduling/scheduling_policy.cc @@ -0,0 +1,257 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
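The scheduling_policy.cc file introduced here consolidates the spread, hybrid, and random policies in one place. The spread policy keeps a persistent round-robin cursor (spread_scheduling_next_index_) so that consecutive requests start scanning at different nodes. Here is a self-contained toy of just that rotation, with ints standing in for node IDs and the feasibility/availability checks reduced to a caller-supplied predicate (the names are illustrative, not Ray APIs):

#include <cstddef>
#include <functional>
#include <vector>

// Scan the sorted node list starting at a persistent cursor; after a pick,
// advance the cursor past it so the next request lands on a different node.
int PickSpreadNode(const std::vector<int> &sorted_nodes,
                   size_t &next_index,
                   const std::function<bool(int)> &is_schedulable) {
  for (size_t i = 0; i < sorted_nodes.size(); ++i) {
    const size_t idx = (next_index + i) % sorted_nodes.size();
    if (is_schedulable(sorted_nodes[idx])) {
      next_index = (idx + 1) % sorted_nodes.size();
      return sorted_nodes[idx];
    }
  }
  return -1;  // Nothing schedulable; the real policy falls back to HybridPolicy.
}

Calling this twice with the same cursor returns two different nodes when both are schedulable, which is what SpreadPolicyTest further down asserts.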
+ +#include "ray/raylet/scheduling/scheduling_policy.h" + +#include + +#include "ray/util/container_util.h" + +namespace ray { + +namespace raylet_scheduling_policy { +namespace { + +bool IsGPURequest(const ResourceRequest &resource_request) { + if (resource_request.predefined_resources.size() <= GPU) { + return false; + } + return resource_request.predefined_resources[GPU] > 0; +} + +bool DoesNodeHaveGPUs(const NodeResources &resources) { + if (resources.predefined_resources.size() <= GPU) { + return false; + } + return resources.predefined_resources[GPU].total > 0; +} +} // namespace + +scheduling::NodeID SchedulingPolicy::SpreadPolicy( + const ResourceRequest &resource_request, + bool force_spillback, + bool require_available, + std::function is_node_available) { + std::vector round; + round.reserve(nodes_.size()); + for (const auto &pair : nodes_) { + round.emplace_back(pair.first); + } + std::sort(round.begin(), round.end()); + + size_t round_index = spread_scheduling_next_index_; + for (size_t i = 0; i < round.size(); ++i, ++round_index) { + const auto &node_id = round[round_index % round.size()]; + const auto &node = map_find_or_die(nodes_, node_id); + if (node_id == local_node_id_ && force_spillback) { + continue; + } + if (!is_node_available(node_id) || + !node.GetLocalView().IsFeasible(resource_request) || + !node.GetLocalView().IsAvailable(resource_request, true)) { + continue; + } + + spread_scheduling_next_index_ = ((round_index + 1) % round.size()); + return node_id; + } + + return HybridPolicy( + resource_request, 0, force_spillback, require_available, is_node_available); +} + +scheduling::NodeID SchedulingPolicy::HybridPolicyWithFilter( + const ResourceRequest &resource_request, + float spread_threshold, + bool force_spillback, + bool require_available, + std::function is_node_available, + NodeFilter node_filter) { + // Step 1: Generate the traversal order. We guarantee that the first node is local, to + // encourage local scheduling. The rest of the traversal order should be globally + // consistent, to encourage using "warm" workers. + std::vector round; + round.reserve(nodes_.size()); + const auto local_it = nodes_.find(local_node_id_); + RAY_CHECK(local_it != nodes_.end()); + auto predicate = [node_filter, &is_node_available]( + scheduling::NodeID node_id, const NodeResources &node_resources) { + if (!is_node_available(node_id)) { + return false; + } + if (node_filter == NodeFilter::kAny) { + return true; + } + const bool has_gpu = DoesNodeHaveGPUs(node_resources); + if (node_filter == NodeFilter::kGPU) { + return has_gpu; + } + RAY_CHECK(node_filter == NodeFilter::kNonGpu); + return !has_gpu; + }; + + const auto &local_node_view = local_it->second.GetLocalView(); + // If we should include local node at all, make sure it is at the front of the list + // so that + // 1. It's first in traversal order. + // 2. It's easy to avoid sorting it. + if (predicate(local_node_id_, local_node_view) && !force_spillback) { + round.push_back(local_node_id_); + } + + const auto start_index = round.size(); + for (const auto &pair : nodes_) { + if (pair.first != local_node_id_ && + predicate(pair.first, pair.second.GetLocalView())) { + round.push_back(pair.first); + } + } + // Sort all the nodes, making sure that if we added the local node in front, it stays in + // place. 
+ std::sort(round.begin() + start_index, round.end()); + + scheduling::NodeID best_node_id = scheduling::NodeID::Nil(); + float best_utilization_score = INFINITY; + bool best_is_available = false; + + // Step 2: Perform the round robin. + auto round_it = round.begin(); + for (; round_it != round.end(); round_it++) { + const auto &node_id = *round_it; + const auto &it = nodes_.find(node_id); + RAY_CHECK(it != nodes_.end()); + const auto &node = it->second; + if (!node.GetLocalView().IsFeasible(resource_request)) { + continue; + } + + bool ignore_pull_manager_at_capacity = false; + if (node_id == local_node_id_) { + // It's okay if the local node's pull manager is at + // capacity because we will eventually spill the task + // back from the waiting queue if its args cannot be + // pulled. + ignore_pull_manager_at_capacity = true; + } + bool is_available = node.GetLocalView().IsAvailable(resource_request, + ignore_pull_manager_at_capacity); + float critical_resource_utilization = + node.GetLocalView().CalculateCriticalResourceUtilization(); + RAY_LOG(DEBUG) << "Node " << node_id.ToInt() << " is " + << (is_available ? "available" : "not available") << " for request " + << resource_request.DebugString() + << " with critical resource utilization " + << critical_resource_utilization << " based on local view " + << node.GetLocalView().DebugString(); + if (critical_resource_utilization < spread_threshold) { + critical_resource_utilization = 0; + } + + bool update_best_node = false; + + if (is_available) { + // Always prioritize available nodes over nodes where the task must be queued first. + if (!best_is_available) { + update_best_node = true; + } else if (critical_resource_utilization < best_utilization_score) { + // Break ties between available nodes by their critical resource utilization. + update_best_node = true; + } + } else if (!best_is_available && + critical_resource_utilization < best_utilization_score && + !require_available) { + // Pick the best feasible node by critical resource utilization. + update_best_node = true; + } + + if (update_best_node) { + best_node_id = node_id; + best_utilization_score = critical_resource_utilization; + best_is_available = is_available; + } + } + + return best_node_id; +} + +scheduling::NodeID SchedulingPolicy::HybridPolicy( + const ResourceRequest &resource_request, + float spread_threshold, + bool force_spillback, + bool require_available, + std::function is_node_available, + bool scheduler_avoid_gpu_nodes) { + if (!scheduler_avoid_gpu_nodes || IsGPURequest(resource_request)) { + return HybridPolicyWithFilter(resource_request, + spread_threshold, + force_spillback, + require_available, + std::move(is_node_available)); + } + + // Try schedule on non-GPU nodes. 
+ auto best_node_id = HybridPolicyWithFilter(resource_request, + spread_threshold, + force_spillback, + /*require_available*/ true, + is_node_available, + NodeFilter::kNonGpu); + if (!best_node_id.IsNil()) { + return best_node_id; + } + + // If we cannot find any available node from non-gpu nodes, fall back to the original + // scheduling. + return HybridPolicyWithFilter(resource_request, + spread_threshold, + force_spillback, + require_available, + is_node_available); +} + +scheduling::NodeID SchedulingPolicy::RandomPolicy( + const ResourceRequest &resource_request, + std::function is_node_available) { + scheduling::NodeID best_node = scheduling::NodeID::Nil(); + if (nodes_.empty()) { + return best_node; + } + + std::uniform_int_distribution distribution(0, nodes_.size() - 1); + int idx = distribution(gen_); + auto iter = std::next(nodes_.begin(), idx); + for (size_t i = 0; i < nodes_.size(); ++i) { + // TODO(scv119): if a lot of nodes have died or can't fulfill the resource + // requirement, the distribution might not be even. + const auto &node_id = iter->first; + const auto &node = iter->second; + if (is_node_available(node_id) && node.GetLocalView().IsFeasible(resource_request) && + node.GetLocalView().IsAvailable(resource_request, + /*ignore_pull_manager_at_capacity*/ true)) { + best_node = iter->first; + break; + } + ++iter; + if (iter == nodes_.end()) { + iter = nodes_.begin(); + } + } + RAY_LOG(DEBUG) << "RandomPolicy, best_node = " << best_node.ToInt() + << ", # nodes = " << nodes_.size() + << ", resource_request = " << resource_request.DebugString(); + return best_node; +} + +} // namespace raylet_scheduling_policy +} // namespace ray diff --git a/src/ray/raylet/scheduling/scheduling_policy.h b/src/ray/raylet/scheduling/scheduling_policy.h new file mode 100644 index 000000000..469296594 --- /dev/null +++ b/src/ray/raylet/scheduling/scheduling_policy.h @@ -0,0 +1,129 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "ray/common/ray_config.h" +#include "ray/gcs/gcs_client/gcs_client.h" +#include "ray/raylet/scheduling/cluster_resource_data.h" + +namespace ray { +namespace raylet_scheduling_policy { + +class SchedulingPolicy { + public: + SchedulingPolicy(scheduling::NodeID local_node_id, + const absl::flat_hash_map &nodes) + : local_node_id_(local_node_id), + nodes_(nodes), + gen_(std::chrono::high_resolution_clock::now().time_since_epoch().count()) {} + + /// This scheduling policy was designed with the following assumptions in mind: + /// 1. Scheduling a task on a new node incurs a cold start penalty (warming the worker + /// pool). + /// 2. Past a certain utilization threshold, a big noisy neighbor problem occurs + /// (caused by object spilling). + /// 3. Locality is helpful, but generally outweighed by (1) and (2). + /// + /// In order to solve these problems, we use the following scheduling policy. + /// 1. Generate a traversal. + /// 2. Run a priority scheduler.
+ /// + /// A node's priorities are determined by the following factors: + /// * Always skip infeasible nodes. + /// * Always prefer available nodes over feasible nodes. + /// * Break ties in available/feasible by critical resource utilization. + /// * Critical resource utilization below a threshold should be truncated to 0. + /// + /// The traversal order should: + /// * Prioritize the local node above all others. + /// * All other nodes should have a globally fixed priority across the cluster. + /// + /// We call this a hybrid policy because below the threshold, the traversal and + /// truncation properties will lead to packing of nodes. Above the threshold, the policy + /// will act like a traditional weighted round robin. + /// + /// \param resource_request: The resource request we're attempting to schedule. + /// \param scheduler_avoid_gpu_nodes: if set, we would try scheduling + /// CPU-only requests on CPU-only nodes, and will fall back to scheduling on GPU nodes if + /// needed. + /// + /// \return -1 if the task is unfeasible, otherwise the node id (key in `nodes`) to + /// schedule on. + scheduling::NodeID HybridPolicy( + const ResourceRequest &resource_request, + float spread_threshold, + bool force_spillback, + bool require_available, + std::function is_node_available, + bool scheduler_avoid_gpu_nodes = RayConfig::instance().scheduler_avoid_gpu_nodes()); + + /// Round robin among available nodes. + /// If there are no available nodes, fall back to the hybrid policy. + scheduling::NodeID SpreadPolicy( + const ResourceRequest &resource_request, + bool force_spillback, + bool require_available, + std::function is_node_available); + + /// Policy that "randomly" picks a node that could fulfil the request. + /// TODO(scv119): if a lot of nodes have died or can't fulfill the resource + /// requirement, the distribution might not be even. + scheduling::NodeID RandomPolicy( + const ResourceRequest &resource_request, + std::function is_node_available); + + private: + /// Identifier of local node. + const scheduling::NodeID local_node_id_; + /// List of nodes in the cluster and their resources organized as a map. + /// The key of the map is the node ID. + const absl::flat_hash_map &nodes_; + // The node to start round robin if it's spread scheduling. + // The index may be inaccurate when nodes are added or removed dynamically, + // but it should still be better than always scanning from 0 for spread scheduling. + size_t spread_scheduling_next_index_ = 0; + /// Internally maintained random number generator. + std::mt19937_64 gen_; + + enum class NodeFilter { + /// Default scheduling. + kAny, + /// Schedule on GPU only nodes. + kGPU, + /// Schedule on nodes that don't have GPUs. Since GPUs are scarcer resources, we + /// need special handling for this. + kNonGpu + }; + + /// \param resource_request: The resource request we're attempting to schedule. + /// \param node_filter: defines the subset of nodes we are allowed to schedule on. + /// Can be one of kAny (can schedule on all nodes), kGPU (can only schedule on GPU + /// nodes), or kNonGpu (can only schedule on non-GPU nodes). + /// + /// \return -1 if the task is unfeasible, otherwise the node id (key in `nodes`) to + /// schedule on.
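The priority rule documented above is compact but easy to misread, so here is a self-contained sketch of the same selection loop over toy candidates (the real code evaluates NodeResources and has a pull-manager special case for the local node; this reduces each node to three fields and is illustrative only):

#include <cmath>
#include <vector>

struct Candidate {
  int node_id;
  bool feasible;      // Can the node ever run the request?
  bool available;     // Can it run the request right now?
  float utilization;  // Critical resource utilization in [0, 1].
};

// Available beats merely-feasible; ties break on utilization; utilization
// under the spread threshold is truncated to 0 so lightly loaded nodes look
// identical and the traversal order (local node first) decides -- packing.
int PickHybridNode(const std::vector<Candidate> &traversal_order,
                   float spread_threshold,
                   bool require_available) {
  int best_id = -1;
  float best_score = INFINITY;
  bool best_available = false;
  for (const auto &c : traversal_order) {
    if (!c.feasible) {
      continue;  // Infeasible nodes are always skipped.
    }
    const float score = c.utilization < spread_threshold ? 0 : c.utilization;
    const bool better =
        c.available ? (!best_available || score < best_score)
                    : (!best_available && !require_available && score < best_score);
    if (better) {
      best_id = c.node_id;
      best_score = score;
      best_available = c.available;
    }
  }
  return best_id;  // -1 when every node is infeasible (NodeID::Nil() in the real code).
}

With a threshold around 0.5, nodes at 10% and 30% utilization both score 0 and the first one in traversal order wins (packing); above the threshold the raw utilization takes over and load spreads, which is the behavior the truncation and tie-break tests below pin down.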
+ scheduling::NodeID HybridPolicyWithFilter( + const ResourceRequest &resource_request, + float spread_threshold, + bool force_spillback, + bool require_available, + std::function is_node_available, + NodeFilter node_filter = NodeFilter::kAny); +}; +} // namespace raylet_scheduling_policy +} // namespace ray diff --git a/src/ray/raylet/scheduling/scheduling_policy_test.cc b/src/ray/raylet/scheduling/scheduling_policy_test.cc new file mode 100644 index 000000000..d77f7fe82 --- /dev/null +++ b/src/ray/raylet/scheduling/scheduling_policy_test.cc @@ -0,0 +1,429 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/raylet/scheduling/scheduling_policy.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace ray { + +namespace raylet { + +using ::testing::_; + +NodeResources CreateNodeResources(double available_cpu, + double total_cpu, + double available_memory, + double total_memory, + double available_gpu, + double total_gpu) { + NodeResources resources; + resources.predefined_resources = {{available_cpu, total_cpu}, + {available_memory, total_memory}, + {available_gpu, total_gpu}}; + return resources; +} + +class SchedulingPolicyTest : public ::testing::Test { + public: + scheduling::NodeID local_node = scheduling::NodeID(0); + scheduling::NodeID remote_node = scheduling::NodeID(1); + scheduling::NodeID remote_node_2 = scheduling::NodeID(2); + scheduling::NodeID remote_node_3 = scheduling::NodeID(3); + absl::flat_hash_map nodes; +}; + +TEST_F(SchedulingPolicyTest, SpreadPolicyTest) { + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(20, 20, 0, 0, 0, 0)); + // Unavailable node + nodes.emplace(remote_node, CreateNodeResources(0, 20, 0, 0, 0, 0)); + // Infeasible node + nodes.emplace(remote_node_2, CreateNodeResources(0, 0, 0, 0, 0, 0)); + nodes.emplace(remote_node_3, CreateNodeResources(20, 20, 0, 0, 0, 0)); + + raylet_scheduling_policy::SchedulingPolicy scheduling_policy(local_node, nodes); + + auto to_schedule = + scheduling_policy.SpreadPolicy(req, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, local_node); + + to_schedule = + scheduling_policy.SpreadPolicy(req, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, remote_node_3); + + to_schedule = scheduling_policy.SpreadPolicy( + req, /*force_spillback=*/true, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, remote_node_3); +} + +TEST_F(SchedulingPolicyTest, RandomPolicyTest) { + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(20, 20, 0, 0, 0, 0)); + nodes.emplace(remote_node, CreateNodeResources(20, 20, 0, 0, 0, 0)); + // Unavailable node + nodes.emplace(remote_node_2, CreateNodeResources(0, 20, 0, 0, 0, 0)); + // Infeasible node + nodes.emplace(remote_node_3, CreateNodeResources(0, 0, 0, 0, 0, 0)); + + raylet_scheduling_policy::SchedulingPolicy scheduling_policy(local_node, 
nodes); + + std::map decisions; + size_t num_node_0_picks = 0; + size_t num_node_1_picks = 0; + for (int i = 0; i < 1000; i++) { + auto to_schedule = scheduling_policy.RandomPolicy(req, [](auto) { return true; }); + ASSERT_TRUE(to_schedule.ToInt() >= 0); + ASSERT_TRUE(to_schedule.ToInt() <= 1); + if (to_schedule.ToInt() == 0) { + num_node_0_picks++; + } else { + num_node_1_picks++; + } + } + // It's extremely unlikely that only node 0 or only node 1 is picked across 1000 runs. + ASSERT_TRUE(num_node_0_picks > 0); + ASSERT_TRUE(num_node_1_picks > 0); +} + +TEST_F(SchedulingPolicyTest, FeasibleDefinitionTest) { + auto task_req1 = + ResourceMapToResourceRequest({{"CPU", 1}, {"object_store_memory", 1}}, false); + auto task_req2 = ResourceMapToResourceRequest({{"CPU", 1}}, false); + { + // Don't break with a non-resized predefined resources array. + NodeResources resources; + resources.predefined_resources = {{0, 2.0}}; + ASSERT_FALSE(resources.IsFeasible(task_req1)); + ASSERT_TRUE(resources.IsFeasible(task_req2)); + } + + { + // After resizing, make sure it doesn't break with resources with 0 total. + NodeResources resources; + resources.predefined_resources = {{0, 2.0}}; + resources.predefined_resources.resize(PredefinedResources_MAX); + ASSERT_FALSE(resources.IsFeasible(task_req1)); + ASSERT_TRUE(resources.IsFeasible(task_req2)); + } +} + +TEST_F(SchedulingPolicyTest, AvailableDefinitionTest) { + auto task_req1 = + ResourceMapToResourceRequest({{"CPU", 1}, {"object_store_memory", 1}}, false); + auto task_req2 = ResourceMapToResourceRequest({{"CPU", 1}}, false); + { + // Don't break with a non-resized predefined resources array. + NodeResources resources; + resources.predefined_resources = {{2, 2.0}}; + ASSERT_FALSE(resources.IsAvailable(task_req1)); + ASSERT_TRUE(resources.IsAvailable(task_req2)); + } + + { + // After resizing, make sure it doesn't break with resources with 0 total. + NodeResources resources; + resources.predefined_resources = {{2, 2.0}}; + resources.predefined_resources.resize(PredefinedResources_MAX); + ASSERT_FALSE(resources.IsAvailable(task_req1)); + ASSERT_TRUE(resources.IsAvailable(task_req2)); + } +} + +TEST_F(SchedulingPolicyTest, CriticalResourceUtilizationDefinitionTest) { + { + // Don't break with a non-resized predefined resources array. + NodeResources resources; + resources.predefined_resources = {{1.0, 2.0}}; + ASSERT_EQ(resources.CalculateCriticalResourceUtilization(), 0.5); + } + + { + // After resizing, make sure it doesn't break with resources with 0 total.
+ NodeResources resources; + resources.predefined_resources = {{1.0, 2.0}}; + resources.predefined_resources.resize(PredefinedResources_MAX); + ASSERT_EQ(resources.CalculateCriticalResourceUtilization(), 0.5); + } + + { + // Basic test of max + NodeResources resources; + resources.predefined_resources = {/* CPU */ {1.0, 2.0}, + /* MEM */ {0.25, 1}, + /* GPU (skipped) */ {1, 2}, + /* OBJECT_STORE_MEM*/ {50, 100}}; + resources.predefined_resources.resize(PredefinedResources_MAX); + ASSERT_EQ(resources.CalculateCriticalResourceUtilization(), 0.75); + } + + { + // Skip GPU + NodeResources resources; + resources.predefined_resources = {/* CPU */ {1.0, 2.0}, + /* MEM */ {0.25, 1}, + /* GPU (skipped) */ {0, 2}, + /* OBJECT_STORE_MEM*/ {50, 100}}; + resources.predefined_resources.resize(PredefinedResources_MAX); + ASSERT_EQ(resources.CalculateCriticalResourceUtilization(), 0.75); + } +} + +TEST_F(SchedulingPolicyTest, AvailableTruncationTest) { + // In this test, the local node and a remote node are both available. The remote node + // has a lower critical resource utilization, but they're both truncated to 0, so we + // should still pick the local node (due to traversal order). + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(1, 2, 0, 0, 0, 0)); + nodes.emplace(remote_node, CreateNodeResources(0.75, 2, 0, 0, 0, 0)); + + auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.51, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, local_node); +} + +TEST_F(SchedulingPolicyTest, AvailableTieBreakTest) { + // In this test, the local node and a remote node are both available. The remote node + // has a lower critical resource utilization so we schedule on it. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(1, 2, 0, 0, 0, 0)); + nodes.emplace(remote_node, CreateNodeResources(1.5, 2, 0, 0, 0, 0)); + + auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.50, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, remote_node); +} + +TEST_F(SchedulingPolicyTest, AvailableOverFeasibleTest) { + // In this test, the local node is feasible and has a lower critical resource + // utilization, but the remote node can run the task immediately, so we pick the remote + // node. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + nodes.emplace(local_node, CreateNodeResources(10, 10, 0, 0, 0, 1)); + nodes.emplace(remote_node, CreateNodeResources(1, 10, 0, 0, 1, 1)); + + auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.50, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, remote_node); +} + +TEST_F(SchedulingPolicyTest, InfeasibleTest) { + // All the nodes are infeasible, so we return -1. 
+ ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + nodes.emplace(local_node, CreateNodeResources(10, 10, 0, 0, 0, 0)); + nodes.emplace(remote_node, CreateNodeResources(1, 10, 0, 0, 0, 0)); + + auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.50, false, false, [](auto) { return true; }); + ASSERT_TRUE(to_schedule.IsNil()); +} + +TEST_F(SchedulingPolicyTest, BarelyFeasibleTest) { + // Test the edge case where a task requires all of a node's resources, and the node is + // fully utilized. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(0, 1, 0, 0, 0, 1)); + + auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.50, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, local_node); +} + +TEST_F(SchedulingPolicyTest, TruncationAcrossFeasibleNodesTest) { + // Same as AvailableTruncationTest except now none of the nodes are available, but the + // tie break logic should apply to feasible nodes too. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + nodes.emplace(local_node, CreateNodeResources(1, 2, 0, 0, 0, 1)); + nodes.emplace(remote_node, CreateNodeResources(0.75, 2, 0, 0, 0, 1)); + + auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.51, false, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, local_node); +} + +TEST_F(SchedulingPolicyTest, ForceSpillbackIfAvailableTest) { + // The local node is better, but we force spillback, so we'll schedule on a non-local + // node anyway. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1)); + nodes.emplace(remote_node, CreateNodeResources(1, 10, 0, 0, 1, 10)); + + auto to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.51, true, true, [](auto) { return true; }); + ASSERT_EQ(to_schedule, remote_node); +} + +TEST_F(SchedulingPolicyTest, AvoidSchedulingCPURequestsOnGPUNodes) { + nodes.emplace(local_node, CreateNodeResources(10, 10, 0, 0, 1, 1)); + nodes.emplace(remote_node, CreateNodeResources(1, 2, 0, 0, 0, 0)); + + { + // The local node is better, but it has GPUs, the request is + // non-GPU, and the remote node does not have GPUs, thus + // we should schedule on the remote node. + const ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + const auto to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + ResourceMapToResourceRequest({{"CPU", 1}}, false), + 0.51, + false, + true, + [](auto) { return true; }, + true); + ASSERT_EQ(to_schedule, remote_node); + } + { + // A GPU request should be scheduled on a GPU node. + const ResourceRequest req = ResourceMapToResourceRequest({{"GPU", 1}}, false); + const auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, 0.51, false, true, [](auto) { return true; }, true); + ASSERT_EQ(to_schedule, local_node); + } + { + // A CPU request can be scheduled on a CPU node.
+ const ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + const auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, 0.51, false, true, [](auto) { return true; }, true); + ASSERT_EQ(to_schedule, remote_node); + } + { + // A mixed CPU/GPU request should be scheduled on a GPU node. + const ResourceRequest req = + ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + const auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, 0.51, false, true, [](auto) { return true; }, true); + ASSERT_EQ(to_schedule, local_node); + } +} + +TEST_F(SchedulingPolicyTest, ScheduleCPURequestsOnGPUNodeAsALastResort) { + // Schedule on the remote (GPU) node, even though the request is CPU-only, because + // no CPU-only node has CPUs available. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + nodes.emplace(local_node, CreateNodeResources(0, 10, 0, 0, 0, 0)); + nodes.emplace(remote_node, CreateNodeResources(1, 1, 0, 0, 1, 1)); + + const auto to_schedule = + raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, 0.51, false, true, [](auto) { return true; }, true); + ASSERT_EQ(to_schedule, remote_node); +} + +TEST_F(SchedulingPolicyTest, ForceSpillbackTest) { + // The local node is available but disqualified. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1)); + nodes.emplace(remote_node, CreateNodeResources(0, 2, 0, 0, 0, 1)); + + auto to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.51, true, false, [](auto) { return true; }); + ASSERT_EQ(to_schedule, remote_node); +} + +TEST_F(SchedulingPolicyTest, ForceSpillbackOnlyFeasibleLocallyTest) { + // The task is only feasible on the local node, but we force spillback, so there is + // no valid node to schedule on and the policy returns Nil. + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + + nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1)); + nodes.emplace(remote_node, CreateNodeResources(0, 2, 0, 0, 0, 0)); + + auto to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy(req, 0.51, true, false, [](auto) { return true; }); + ASSERT_TRUE(to_schedule.IsNil()); +} + +TEST_F(SchedulingPolicyTest, NonGpuNodePreferredSchedulingTest) { + // Prefer to schedule on CPU-only nodes first; + // GPU nodes should only be used as a last resort.
+ + // local {CPU:2, GPU:1} + // Remote {CPU: 2} + nodes.emplace(local_node, CreateNodeResources(2, 2, 0, 0, 1, 1)); + nodes.emplace(remote_node, CreateNodeResources(2, 2, 0, 0, 0, 0)); + nodes.emplace(remote_node_2, CreateNodeResources(3, 3, 0, 0, 0, 0)); + + ResourceRequest req = ResourceMapToResourceRequest({{"CPU", 1}}, false); + auto to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, + 0.51, + false, + true, + [](auto) { return true; }, + /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, remote_node); + + req = ResourceMapToResourceRequest({{"CPU", 3}}, false); + to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, + 0.51, + false, + true, + [](auto) { return true; }, + /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, remote_node_2); + + req = ResourceMapToResourceRequest({{"CPU", 1}, {"GPU", 1}}, false); + to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, + 0.51, + false, + true, + [](auto) { return true; }, + /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, local_node); + + req = ResourceMapToResourceRequest({{"CPU", 2}}, false); + to_schedule = raylet_scheduling_policy::SchedulingPolicy(local_node, nodes) + .HybridPolicy( + req, + 0.51, + false, + true, + [](auto) { return true; }, + /*gpu_avoid_scheduling*/ true); + ASSERT_EQ(to_schedule, remote_node); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace raylet + +} // namespace ray diff --git a/src/ray/raylet/test/local_object_manager_test.cc b/src/ray/raylet/test/local_object_manager_test.cc index be7945203..a980b3ee9 100644 --- a/src/ray/raylet/test/local_object_manager_test.cc +++ b/src/ray/raylet/test/local_object_manager_test.cc @@ -38,7 +38,8 @@ class MockSubscriber : public pubsub::SubscriberInterface { public: bool Subscribe( const std::unique_ptr sub_message, - const rpc::ChannelType channel_type, const rpc::Address &owner_address, + const rpc::ChannelType channel_type, + const rpc::Address &owner_address, const std::string &key_id_binary, pubsub::SubscribeDoneCallback subscribe_done_callback, pubsub::SubscriptionItemCallback subscription_callback, @@ -83,16 +84,19 @@ class MockSubscriber : public pubsub::SubscriberInterface { pubsub::SubscriptionItemCallback subscription_callback, pubsub::SubscriptionFailureCallback subscription_failure_callback)); - MOCK_METHOD3(Unsubscribe, bool(const rpc::ChannelType channel_type, - const rpc::Address &publisher_address, - const std::string &key_id_binary)); + MOCK_METHOD3(Unsubscribe, + bool(const rpc::ChannelType channel_type, + const rpc::Address &publisher_address, + const std::string &key_id_binary)); - MOCK_METHOD2(UnsubscribeChannel, bool(const rpc::ChannelType channel_type, - const rpc::Address &publisher_address)); + MOCK_METHOD2(UnsubscribeChannel, + bool(const rpc::ChannelType channel_type, + const rpc::Address &publisher_address)); - MOCK_CONST_METHOD3(IsSubscribed, bool(const rpc::ChannelType channel_type, - const rpc::Address &publisher_address, - const std::string &key_id_binary)); + MOCK_CONST_METHOD3(IsSubscribed, + bool(const rpc::ChannelType channel_type, + const rpc::Address &publisher_address, + const std::string &key_id_binary)); MOCK_CONST_METHOD0(DebugString, std::string()); @@ -202,7 +206,8 @@ class MockIOWorkerClient : public rpc::CoreWorkerClientInterface { class MockIOWorker : public MockWorker { public: - 
MockIOWorker(WorkerID worker_id, int port, + MockIOWorker(WorkerID worker_id, + int port, std::shared_ptr io_worker) : MockWorker(worker_id, port), io_worker(io_worker) {} @@ -264,7 +269,8 @@ class MockIOWorkerPool : public IOWorkerPoolInterface { class MockObjectBuffer : public Buffer { public: - MockObjectBuffer(size_t size, ObjectID object_id, + MockObjectBuffer(size_t size, + ObjectID object_id, std::shared_ptr> unpins) : size_(size), id_(object_id), unpins_(unpins) {} @@ -292,8 +298,13 @@ class LocalObjectManagerTest : public ::testing::Test { manager_node_id_(NodeID::FromRandom()), max_fused_object_count_(15), manager( - manager_node_id_, "address", 1234, free_objects_batch_size, - /*free_objects_period_ms=*/1000, worker_pool, client_pool, + manager_node_id_, + "address", + 1234, + free_objects_batch_size, + /*free_objects_period_ms=*/1000, + worker_pool, + client_pool, /*max_io_workers=*/2, /*min_spilling_size=*/0, /*is_external_storage_type_fs=*/true, @@ -362,8 +373,8 @@ TEST_F(LocalObjectManagerTest, TestPin) { std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); - auto object = std::make_unique(nullptr, meta_buffer, - std::vector()); + auto object = std::make_unique( + nullptr, meta_buffer, std::vector()); objects.push_back(std::move(object)); } manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address); @@ -388,8 +399,8 @@ TEST_F(LocalObjectManagerTest, TestRestoreSpilledObject) { ObjectID object_id = ObjectID::FromRandom(); object_ids.push_back(object_id); auto data_buffer = std::make_shared(0, object_id, unpins); - auto object = std::make_unique(data_buffer, nullptr, - std::vector()); + auto object = std::make_unique( + data_buffer, nullptr, std::vector()); objects.push_back(std::move(object)); } manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address); @@ -445,8 +456,8 @@ TEST_F(LocalObjectManagerTest, TestExplicitSpill) { ObjectID object_id = ObjectID::FromRandom(); object_ids.push_back(object_id); auto data_buffer = std::make_shared(0, object_id, unpins); - auto object = std::make_unique(data_buffer, nullptr, - std::vector()); + auto object = std::make_unique( + data_buffer, nullptr, std::vector()); objects.push_back(std::move(object)); } manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address); @@ -491,8 +502,8 @@ TEST_F(LocalObjectManagerTest, TestDuplicateSpill) { ObjectID object_id = ObjectID::FromRandom(); object_ids.push_back(object_id); auto data_buffer = std::make_shared(0, object_id, unpins); - auto object = std::make_unique(data_buffer, nullptr, - std::vector()); + auto object = std::make_unique( + data_buffer, nullptr, std::vector()); objects.push_back(std::move(object)); } manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address); @@ -546,8 +557,8 @@ TEST_F(LocalObjectManagerTest, TestSpillObjectsOfSize) { object_ids.push_back(object_id); auto data_buffer = std::make_shared(object_size, object_id, unpins); total_size += object_size; - auto object = std::make_unique(data_buffer, nullptr, - std::vector()); + auto object = std::make_unique( + data_buffer, nullptr, std::vector()); objects.push_back(std::move(object)); } manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address); @@ -616,8 +627,8 @@ TEST_F(LocalObjectManagerTest, TestSpillUptoMaxFuseCount) { object_ids.push_back(object_id); auto data_buffer = 
std::make_shared<MockObjectBuffer>(object_size, object_id, unpins);
     total_size += object_size;
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -663,8 +674,8 @@ TEST_F(LocalObjectManagerTest, TestSpillObjectNotEvictable) {
   unevictable_objects_.emplace(object_id);
   auto data_buffer = std::make_shared<MockObjectBuffer>(object_size, object_id, unpins);
   total_size += object_size;
-  auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                            std::vector<rpc::ObjectReference>());
+  auto object = std::make_unique<RayObject>(
+      data_buffer, nullptr, std::vector<rpc::ObjectReference>());
   objects.push_back(std::move(object));
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -693,8 +704,8 @@ TEST_F(LocalObjectManagerTest, TestSpillUptoMaxThroughput) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(object_size, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -764,8 +775,8 @@ TEST_F(LocalObjectManagerTest, TestSpillError) {
   ObjectID object_id = ObjectID::FromRandom();
   auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-  auto object = std::make_unique<RayObject>(std::move(data_buffer), nullptr,
-                                            std::vector<rpc::ObjectReference>());
+  auto object = std::make_unique<RayObject>(
+      std::move(data_buffer), nullptr, std::vector<rpc::ObjectReference>());
   std::vector<std::unique_ptr<RayObject>> objects;
   objects.push_back(std::move(object));
@@ -812,8 +823,8 @@ TEST_F(LocalObjectManagerTest, TestPartialSpillError) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -849,8 +860,8 @@ TEST_F(LocalObjectManagerTest, TestDeleteNoSpilledObjects) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(std::move(data_buffer), nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        std::move(data_buffer), nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -877,8 +888,8 @@ TEST_F(LocalObjectManagerTest, TestDeleteSpilledObjects) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -926,8 +937,8 @@ TEST_F(LocalObjectManagerTest, TestDeleteURLRefCount) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -987,8 +998,8 @@ TEST_F(LocalObjectManagerTest, TestDeleteSpillingObjectsBlocking) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -1064,8 +1075,8 @@ TEST_F(LocalObjectManagerTest, TestDeleteMaxObjects) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -1116,8 +1127,8 @@ TEST_F(LocalObjectManagerTest, TestDeleteURLRefCountRaceCondition) {
     ObjectID object_id = ObjectID::FromRandom();
     object_ids.push_back(object_id);
     auto data_buffer = std::make_shared<MockObjectBuffer>(0, object_id, unpins);
-    auto object = std::make_unique<RayObject>(data_buffer, nullptr,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        data_buffer, nullptr, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -1177,8 +1188,8 @@ TEST_F(LocalObjectManagerTest, TestDuplicatePin) {
     std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
     auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
     auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
-    auto object = std::make_unique<RayObject>(nullptr, meta_buffer,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        nullptr, meta_buffer, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -1190,8 +1201,8 @@ TEST_F(LocalObjectManagerTest, TestDuplicatePin) {
     std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
     auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
     auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
-    auto object = std::make_unique<RayObject>(nullptr, meta_buffer,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        nullptr, meta_buffer, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -1202,8 +1213,8 @@ TEST_F(LocalObjectManagerTest, TestDuplicatePin) {
     std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
     auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
     auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
-    auto object = std::make_unique<RayObject>(nullptr, meta_buffer,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        nullptr, meta_buffer, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   rpc::Address owner_address2;
@@ -1241,8 +1252,8 @@ TEST_F(LocalObjectManagerTest, TestDuplicatePinAndSpill) {
     std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
     auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
     auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
-    auto object = std::make_unique<RayObject>(nullptr, meta_buffer,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        nullptr, meta_buffer, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
@@ -1270,8 +1281,8 @@ TEST_F(LocalObjectManagerTest, TestDuplicatePinAndSpill) {
     std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
    auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
     auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
-    auto object = std::make_unique<RayObject>(nullptr, meta_buffer,
-                                              std::vector<rpc::ObjectReference>());
+    auto object = std::make_unique<RayObject>(
+        nullptr, meta_buffer, std::vector<rpc::ObjectReference>());
     objects.push_back(std::move(object));
   }
   manager.PinObjectsAndWaitForFree(object_ids, std::move(objects), owner_address);
diff --git a/src/ray/raylet/wait_manager.cc b/src/ray/raylet/wait_manager.cc
index fcf47c324..5c4667879 100644
--- a/src/ray/raylet/wait_manager.cc
+++ b/src/ray/raylet/wait_manager.cc
@@ -19,8 +19,10 @@ namespace ray {
 
 namespace raylet {
 
-void WaitManager::Wait(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
-                       uint64_t num_required_objects, const WaitCallback &callback) {
+void WaitManager::Wait(const std::vector<ObjectID> &object_ids,
+                       int64_t timeout_ms,
+                       uint64_t num_required_objects,
+                       const WaitCallback &callback) {
   RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
   RAY_CHECK_NE(num_required_objects, 0u);
   RAY_CHECK_LE(num_required_objects, object_ids.size());
diff --git a/src/ray/raylet/wait_manager.h b/src/ray/raylet/wait_manager.h
index 310897610..bc4d79ab9 100644
--- a/src/ray/raylet/wait_manager.h
+++ b/src/ray/raylet/wait_manager.h
@@ -42,8 +42,10 @@ class WaitManager {
   /// invoking the callback.
   /// \param callback Invoked when either timeout_ms is satisfied OR num_required_objects
   /// is satisfied.
-  void Wait(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
-            uint64_t num_required_objects, const WaitCallback &callback);
+  void Wait(const std::vector<ObjectID> &object_ids,
+            int64_t timeout_ms,
+            uint64_t num_required_objects,
+            const WaitCallback &callback);
 
   /// This is invoked whenever an object becomes locally available.
   ///
@@ -55,8 +57,10 @@ class WaitManager {
  private:
   struct WaitRequest {
-    WaitRequest(int64_t timeout_ms, const WaitCallback &callback,
-                const std::vector<ObjectID> &object_ids, uint64_t num_required_objects)
+    WaitRequest(int64_t timeout_ms,
+                const WaitCallback &callback,
+                const std::vector<ObjectID> &object_ids,
+                uint64_t num_required_objects)
         : timeout_ms(timeout_ms),
           callback(callback),
           object_ids(object_ids),
diff --git a/src/ray/raylet/wait_manager_test.cc b/src/ray/raylet/wait_manager_test.cc
index bb4260b62..7293b35c1 100644
--- a/src/ray/raylet/wait_manager_test.cc
+++ b/src/ray/raylet/wait_manager_test.cc
@@ -48,7 +48,9 @@ TEST_F(WaitManagerTest, TestImmediatelyCompleteWait) {
   local_objects.emplace(obj1);
   std::vector<ObjectID> ready;
   std::vector<ObjectID> remaining;
-  wait_manager.Wait(std::vector<ObjectID>{obj1, obj2}, -1, 1,
+  wait_manager.Wait(std::vector<ObjectID>{obj1, obj2},
+                    -1,
+                    1,
                     [&](std::vector<ObjectID> _ready, std::vector<ObjectID> _remaining) {
                       ready = _ready;
                       remaining = _remaining;
@@ -61,7 +63,9 @@ TEST_F(WaitManagerTest, TestImmediatelyCompleteWait) {
   ready.clear();
   remaining.clear();
   // The wait should immediately complete since the timeout is 0.
-  wait_manager.Wait(std::vector<ObjectID>{obj1, obj2}, 0, 1,
+  wait_manager.Wait(std::vector<ObjectID>{obj1, obj2},
+                    0,
+                    1,
                     [&](std::vector<ObjectID> _ready, std::vector<ObjectID> _remaining) {
                       ready = _ready;
                       remaining = _remaining;
@@ -79,12 +83,16 @@ TEST_F(WaitManagerTest, TestMultiWaits) {
   std::vector<ObjectID> remaining1;
   std::vector<ObjectID> ready2;
   std::vector<ObjectID> remaining2;
-  wait_manager.Wait(std::vector<ObjectID>{obj1}, -1, 1,
+  wait_manager.Wait(std::vector<ObjectID>{obj1},
+                    -1,
+                    1,
                     [&](std::vector<ObjectID> _ready, std::vector<ObjectID> _remaining) {
                       ready1 = _ready;
                       remaining1 = _remaining;
                     });
-  wait_manager.Wait(std::vector<ObjectID>{obj1}, -1, 1,
+  wait_manager.Wait(std::vector<ObjectID>{obj1},
+                    -1,
+                    1,
                     [&](std::vector<ObjectID> _ready, std::vector<ObjectID> _remaining) {
                       ready2 = _ready;
                       remaining2 = _remaining;
@@ -109,7 +117,9 @@ TEST_F(WaitManagerTest, TestWaitTimeout) {
   ObjectID obj1 = ObjectID::FromRandom();
   std::vector<ObjectID> ready;
   std::vector<ObjectID> remaining;
-  wait_manager.Wait(std::vector<ObjectID>{obj1}, 1, 1,
+  wait_manager.Wait(std::vector<ObjectID>{obj1},
+                    1,
+                    1,
                     [&](std::vector<ObjectID> _ready, std::vector<ObjectID> _remaining) {
                       ready = _ready;
                       remaining = _remaining;
diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc
index d55501e9e..cac5dff73 100644
--- a/src/ray/raylet/worker.cc
+++ b/src/ray/raylet/worker.cc
@@ -26,11 +26,15 @@ namespace ray {
 namespace raylet {
 
 /// A constructor responsible for initializing the state of a worker.
-Worker::Worker(const JobID &job_id, const int runtime_env_hash, const WorkerID &worker_id,
-               const Language &language, rpc::WorkerType worker_type,
+Worker::Worker(const JobID &job_id,
+               const int runtime_env_hash,
+               const WorkerID &worker_id,
+               const Language &language,
+               rpc::WorkerType worker_type,
                const std::string &ip_address,
                std::shared_ptr<ClientConnection> connection,
-               rpc::ClientCallManager &client_call_manager, StartupToken startup_token)
+               rpc::ClientCallManager &client_call_manager,
+               StartupToken startup_token)
     : worker_id_(worker_id),
       startup_token_(startup_token),
       language_(language),
diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h
index dce852a66..611a5d706 100644
--- a/src/ray/raylet/worker.h
+++ b/src/ray/raylet/worker.h
@@ -123,10 +123,15 @@ class Worker : public WorkerInterface {
  public:
   /// A constructor that initializes a worker object.
   /// NOTE: You MUST manually set the worker process.
-  Worker(const JobID &job_id, const int runtime_env_hash, const WorkerID &worker_id,
-         const Language &language, rpc::WorkerType worker_type,
-         const std::string &ip_address, std::shared_ptr<ClientConnection> connection,
-         rpc::ClientCallManager &client_call_manager, StartupToken startup_token);
+  Worker(const JobID &job_id,
+         const int runtime_env_hash,
+         const WorkerID &worker_id,
+         const Language &language,
+         rpc::WorkerType worker_type,
+         const std::string &ip_address,
+         std::shared_ptr<ClientConnection> connection,
+         rpc::ClientCallManager &client_call_manager,
+         StartupToken startup_token);
 
   /// A destructor responsible for freeing all worker state.
   ~Worker() {}
 
   rpc::WorkerType GetWorkerType() const;
 
diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc
index 669351028..699fc8e66 100644
--- a/src/ray/raylet/worker_pool.cc
+++ b/src/ray/raylet/worker_pool.cc
@@ -29,8 +29,11 @@
 #include "ray/util/logging.h"
 #include "ray/util/util.h"
 
-DEFINE_stats(worker_register_time_ms, "end to end latency of register a worker process.",
-             (), ({1, 10, 100, 1000, 10000}, ), ray::stats::HISTOGRAM);
+DEFINE_stats(worker_register_time_ms,
+             "end to end latency of register a worker process.",
+             (),
+             ({1, 10, 100, 1000, 10000}, ),
+             ray::stats::HISTOGRAM);
 
 namespace {
 
@@ -64,16 +67,21 @@ namespace ray {
 
 namespace raylet {
 
-WorkerPool::WorkerPool(instrumented_io_context &io_service, const NodeID node_id,
-                       const std::string node_address, int num_workers_soft_limit,
+WorkerPool::WorkerPool(instrumented_io_context &io_service,
+                       const NodeID node_id,
+                       const std::string node_address,
+                       int num_workers_soft_limit,
                        int num_initial_python_workers_for_first_job,
-                       int maximum_startup_concurrency, int min_worker_port,
-                       int max_worker_port, const std::vector<int> &worker_ports,
+                       int maximum_startup_concurrency,
+                       int min_worker_port,
+                       int max_worker_port,
+                       const std::vector<int> &worker_ports,
                        std::shared_ptr<gcs::GcsClient> gcs_client,
                       const WorkerCommandMap &worker_commands,
                        const std::string &native_library_path,
                        std::function<void()> starting_worker_timeout_callback,
-                       int ray_debugger_external, const std::function<double()> get_time)
+                       int ray_debugger_external,
+                       const std::function<double()> get_time)
     : worker_startup_token_counter_(0),
       io_service_(&io_service),
       node_id_(node_id),
@@ -177,9 +185,11 @@ void WorkerPool::PopWorkerCallbackAsync(const PopWorkerCallback &callback,
   // invoking the callback immediately.
   RAY_CHECK(status != PopWorkerStatus::RuntimeEnvCreationFailed);
   // Call back this function asynchronously to make sure executed in different stack.
- io_service_->post([this, callback, worker, - status]() { PopWorkerCallbackInternal(callback, worker, status); }, - "WorkerPool.PopWorkerCallback"); + io_service_->post( + [this, callback, worker, status]() { + PopWorkerCallbackInternal(callback, worker, status); + }, + "WorkerPool.PopWorkerCallback"); } void WorkerPool::PopWorkerCallbackInternal(const PopWorkerCallback &callback, @@ -198,13 +208,19 @@ void WorkerPool::update_worker_startup_token_counter() { } void WorkerPool::AddStartingWorkerProcess( - State &state, const int workers_to_start, const rpc::WorkerType worker_type, - const Process &proc, const std::chrono::high_resolution_clock::time_point &start, + State &state, + const int workers_to_start, + const rpc::WorkerType worker_type, + const Process &proc, + const std::chrono::high_resolution_clock::time_point &start, const rpc::RuntimeEnvInfo &runtime_env_info) { - state.starting_worker_processes.emplace( - worker_startup_token_counter_, - StartingWorkerProcessInfo{workers_to_start, workers_to_start, worker_type, proc, - start, runtime_env_info}); + state.starting_worker_processes.emplace(worker_startup_token_counter_, + StartingWorkerProcessInfo{workers_to_start, + workers_to_start, + worker_type, + proc, + start, + runtime_env_info}); runtime_env_manager_.AddURIReference( kWorkerSetupTokenPrefix + std::to_string(worker_startup_token_counter_), runtime_env_info); @@ -218,9 +234,13 @@ void WorkerPool::RemoveStartingWorkerProcess(State &state, } std::tuple WorkerPool::StartWorkerProcess( - const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, - PopWorkerStatus *status, const std::vector &dynamic_options, - const int runtime_env_hash, const std::string &serialized_runtime_env_context, + const Language &language, + const rpc::WorkerType worker_type, + const JobID &job_id, + PopWorkerStatus *status, + const std::vector &dynamic_options, + const int runtime_env_hash, + const std::string &serialized_runtime_env_context, const std::string &allocated_instances_serialized_json, const rpc::RuntimeEnvInfo &runtime_env_info) { rpc::JobConfig *job_config = nullptr; @@ -299,7 +319,8 @@ std::tuple WorkerPool::StartWorkerProcess( // Append user-defined per-job options here if (language == Language::JAVA) { if (!job_config->jvm_options().empty()) { - options.insert(options.end(), job_config->jvm_options().begin(), + options.insert(options.end(), + job_config->jvm_options().begin(), job_config->jvm_options().end()); } } @@ -325,8 +346,8 @@ std::tuple WorkerPool::StartWorkerProcess( std::vector worker_command_args; for (auto const &token : state.worker_command) { if (token == kWorkerDynamicOptionPlaceholder) { - worker_command_args.insert(worker_command_args.end(), options.begin(), - options.end()); + worker_command_args.insert( + worker_command_args.end(), options.begin(), options.end()); continue; } RAY_CHECK(node_manager_port_ != 0) @@ -454,10 +475,10 @@ std::tuple WorkerPool::StartWorkerProcess( RAY_LOG(INFO) << "Started worker process of " << workers_to_start << " worker(s) with pid " << proc.GetId() << ", the token " << worker_startup_token_counter_; - MonitorStartingWorkerProcess(proc, worker_startup_token_counter_, language, - worker_type); - AddStartingWorkerProcess(state, workers_to_start, worker_type, proc, start, - runtime_env_info); + MonitorStartingWorkerProcess( + proc, worker_startup_token_counter_, language, worker_type); + AddStartingWorkerProcess( + state, workers_to_start, worker_type, proc, start, runtime_env_info); StartupToken 
worker_startup_token = worker_startup_token_counter_; update_worker_startup_token_counter(); if (IsIOWorkerType(worker_type)) { @@ -472,11 +493,12 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, const Language &language, const rpc::WorkerType worker_type) { auto timer = std::make_shared( - *io_service_, boost::posix_time::seconds( - RayConfig::instance().worker_register_timeout_seconds())); + *io_service_, + boost::posix_time::seconds( + RayConfig::instance().worker_register_timeout_seconds())); // Capture timer in lambda to copy it once, so that it can avoid destructing timer. - timer->async_wait([timer, language, proc = proc, proc_startup_token, worker_type, - this](const boost::system::error_code e) mutable { + timer->async_wait([timer, language, proc = proc, proc_startup_token, worker_type, this]( + const boost::system::error_code e) mutable { // check the error code. auto &state = this->GetStateForLanguage(language); // Since this process times out to start, remove it from starting_worker_processes @@ -500,12 +522,20 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, bool used; TaskID task_id; InvokePopWorkerCallbackForProcess(state.starting_dedicated_workers_to_tasks, - proc_startup_token, nullptr, status, &found, - &used, &task_id); + proc_startup_token, + nullptr, + status, + &found, + &used, + &task_id); if (!found) { InvokePopWorkerCallbackForProcess(state.starting_workers_to_tasks, - proc_startup_token, nullptr, status, &found, - &used, &task_id); + proc_startup_token, + nullptr, + status, + &found, + &used, + &task_id); } RemoveStartingWorkerProcess(state, proc_startup_token); if (IsIOWorkerType(worker_type)) { @@ -617,19 +647,22 @@ void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job runtime_env_manager_.AddURIReference(job_id.Hex(), job_config.runtime_env_info()); RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id << ". The runtime environment was " << runtime_env << "."; - CreateRuntimeEnv( - runtime_env, job_id, - [job_id](bool successful, const std::string &serialized_runtime_env_context, - const std::string &setup_error_message) { - if (successful) { - RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job " << job_id - << ". The result context was " << serialized_runtime_env_context - << "."; - } else { - RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job " - << job_id << ". Error message: " << setup_error_message; - } - }); + CreateRuntimeEnv(runtime_env, + job_id, + [job_id](bool successful, + const std::string &serialized_runtime_env_context, + const std::string &setup_error_message) { + if (successful) { + RAY_LOG(INFO) + << "[Eagerly] Create runtime env successful for job " + << job_id << ". The result context was " + << serialized_runtime_env_context << "."; + } else { + RAY_LOG(ERROR) + << "[Eagerly] Couldn't create a runtime environment for job " + << job_id << ". 
Error message: " << setup_error_message; + } + }); } } @@ -655,7 +688,8 @@ boost::optional WorkerPool::GetJobConfig( } Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, - pid_t pid, StartupToken worker_startup_token, + pid_t pid, + StartupToken worker_startup_token, std::function send_reply_callback) { RAY_CHECK(worker); auto &state = GetStateForLanguage(worker->GetLanguage()); @@ -879,8 +913,12 @@ void WorkerPool::PopDeleteWorker( void WorkerPool::InvokePopWorkerCallbackForProcess( absl::flat_hash_map &starting_workers_to_tasks, - StartupToken startup_token, const std::shared_ptr &worker, - const PopWorkerStatus &status, bool *found, bool *worker_used, TaskID *task_id) { + StartupToken startup_token, + const std::shared_ptr &worker, + const PopWorkerStatus &status, + bool *found, + bool *worker_used, + TaskID *task_id) { *found = false; *worker_used = false; auto it = starting_workers_to_tasks.find(startup_token); @@ -907,8 +945,12 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { bool used; TaskID task_id; InvokePopWorkerCallbackForProcess(state.starting_dedicated_workers_to_tasks, - worker->GetStartupToken(), worker, - PopWorkerStatus::OK, &found, &used, &task_id); + worker->GetStartupToken(), + worker, + PopWorkerStatus::OK, + &found, + &used, + &task_id); if (found) { // The worker is used for the actor creation task with dynamic options. if (!used) { @@ -920,8 +962,12 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { } InvokePopWorkerCallbackForProcess(state.starting_workers_to_tasks, - worker->GetStartupToken(), worker, - PopWorkerStatus::OK, &found, &used, &task_id); + worker->GetStartupToken(), + worker, + PopWorkerStatus::OK, + &found, + &used, + &task_id); // The worker is not used for the actor creation task with dynamic options. if (!used) { // Put the worker to the idle pool. @@ -1033,35 +1079,38 @@ void WorkerPool::TryKillingIdleWorkers() { RAY_CHECK(running_size > 0); running_size--; rpc::ExitRequest request; - rpc_client->Exit(request, [this, worker](const ray::Status &status, - const rpc::ExitReply &r) { - RAY_CHECK(pending_exit_idle_workers_.erase(worker->WorkerId())); - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to send exit request: " << status.ToString(); - } + rpc_client->Exit( + request, [this, worker](const ray::Status &status, const rpc::ExitReply &r) { + RAY_CHECK(pending_exit_idle_workers_.erase(worker->WorkerId())); + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to send exit request: " << status.ToString(); + } - // In case of failed to send request, we remove it from pool as well - // TODO (iycheng): We should handle the grpc failure in better way. - if (!status.ok() || r.success()) { - auto &worker_state = GetStateForLanguage(worker->GetLanguage()); - // If we could kill the worker properly, we remove them from the idle pool. - RemoveWorker(worker_state.idle, worker); - // We always mark the worker as dead. - // If the worker is not idle at this moment, we'd want to mark it as dead - // so it won't be reused later. - if (!worker->IsDead()) { - worker->MarkDead(); - } - } else { - // We re-insert the idle worker to the back of the queue if it fails to kill - // the worker (e.g., when the worker owns the object). Without this, if the - // first N workers own objects, it can't kill idle workers that are >= N+1. 
- const auto &idle_pair = idle_of_all_languages_.front(); - idle_of_all_languages_.push_back(idle_pair); - idle_of_all_languages_.pop_front(); - RAY_CHECK(idle_of_all_languages_.size() == idle_of_all_languages_map_.size()); - } - }); + // In case of failed to send request, we remove it from pool as well + // TODO (iycheng): We should handle the grpc failure in better way. + if (!status.ok() || r.success()) { + auto &worker_state = GetStateForLanguage(worker->GetLanguage()); + // If we could kill the worker properly, we remove them from the idle + // pool. + RemoveWorker(worker_state.idle, worker); + // We always mark the worker as dead. + // If the worker is not idle at this moment, we'd want to mark it as dead + // so it won't be reused later. + if (!worker->IsDead()) { + worker->MarkDead(); + } + } else { + // We re-insert the idle worker to the back of the queue if it fails to + // kill the worker (e.g., when the worker owns the object). Without this, + // if the first N workers own objects, it can't kill idle workers that are + // >= N+1. + const auto &idle_pair = idle_of_all_languages_.front(); + idle_of_all_languages_.push_back(idle_pair); + idle_of_all_languages_.pop_front(); + RAY_CHECK(idle_of_all_languages_.size() == + idle_of_all_languages_map_.size()); + } + }); } else { // Even it's a dead worker, we still need to remove them from the pool. RemoveWorker(worker_state.idle, worker); @@ -1092,17 +1141,23 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, std::shared_ptr worker = nullptr; auto start_worker_process_fn = [this, allocated_instances_serialized_json]( - const TaskSpecification &task_spec, State &state, + const TaskSpecification &task_spec, + State &state, std::vector dynamic_options, bool dedicated, const std::string &serialized_runtime_env, const std::string &serialized_runtime_env_context, const PopWorkerCallback &callback) -> Process { PopWorkerStatus status = PopWorkerStatus::OK; - auto [proc, startup_token] = StartWorkerProcess( - task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status, - dynamic_options, task_spec.GetRuntimeEnvHash(), serialized_runtime_env_context, - allocated_instances_serialized_json, task_spec.RuntimeEnvInfo()); + auto [proc, startup_token] = StartWorkerProcess(task_spec.GetLanguage(), + rpc::WorkerType::WORKER, + task_spec.JobId(), + &status, + dynamic_options, + task_spec.GetRuntimeEnvHash(), + serialized_runtime_env_context, + allocated_instances_serialized_json, + task_spec.RuntimeEnvInfo()); if (status == PopWorkerStatus::OK) { RAY_CHECK(proc.IsValid()); WarnAboutSize(); @@ -1144,17 +1199,24 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, if (task_spec.HasRuntimeEnv()) { // create runtime env. 
CreateRuntimeEnv( - task_spec.SerializedRuntimeEnv(), task_spec.JobId(), + task_spec.SerializedRuntimeEnv(), + task_spec.JobId(), [this, start_worker_process_fn, callback, &state, task_spec, dynamic_options]( - bool successful, const std::string &serialized_runtime_env_context, + bool successful, + const std::string &serialized_runtime_env_context, const std::string &setup_error_message) { if (successful) { - start_worker_process_fn(task_spec, state, dynamic_options, true, + start_worker_process_fn(task_spec, + state, + dynamic_options, + true, task_spec.SerializedRuntimeEnv(), - serialized_runtime_env_context, callback); + serialized_runtime_env_context, + callback); } else { process_failed_runtime_env_setup_failed_++; - callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed, + callback(nullptr, + PopWorkerStatus::RuntimeEnvCreationFailed, /*runtime_env_setup_error_message*/ setup_error_message); RAY_LOG(WARNING) << "Create runtime env failed for task " << task_spec.TaskId() @@ -1163,8 +1225,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, }, allocated_instances_serialized_json); } else { - start_worker_process_fn(task_spec, state, dynamic_options, true, "", "", - callback); + start_worker_process_fn( + task_spec, state, dynamic_options, true, "", "", callback); } } } else { @@ -1205,17 +1267,24 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, if (task_spec.HasRuntimeEnv()) { // create runtime env. CreateRuntimeEnv( - task_spec.SerializedRuntimeEnv(), task_spec.JobId(), + task_spec.SerializedRuntimeEnv(), + task_spec.JobId(), [this, start_worker_process_fn, callback, &state, task_spec]( - bool successful, const std::string &serialized_runtime_env_context, + bool successful, + const std::string &serialized_runtime_env_context, const std::string &setup_error_message) { if (successful) { - start_worker_process_fn(task_spec, state, {}, false, + start_worker_process_fn(task_spec, + state, + {}, + false, task_spec.SerializedRuntimeEnv(), - serialized_runtime_env_context, callback); + serialized_runtime_env_context, + callback); } else { process_failed_runtime_env_setup_failed_++; - callback(nullptr, PopWorkerStatus::RuntimeEnvCreationFailed, + callback(nullptr, + PopWorkerStatus::RuntimeEnvCreationFailed, /*runtime_env_setup_error_message*/ setup_error_message); RAY_LOG(WARNING) << "Create runtime env failed for task " << task_spec.TaskId() @@ -1235,7 +1304,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, } } -void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t backlog_size, +void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, + int64_t backlog_size, int64_t num_available_cpus) { // Code path of task that needs a dedicated worker. 
if ((task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) || @@ -1261,8 +1331,8 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t bac << backlog_size << " and available CPUs " << num_available_cpus; for (int i = 0; i < num_needed; i++) { PopWorkerStatus status; - StartWorkerProcess(task_spec.GetLanguage(), rpc::WorkerType::WORKER, - task_spec.JobId(), &status); + StartWorkerProcess( + task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status); } } } @@ -1406,8 +1476,8 @@ void WorkerPool::WarnAboutSize() { << "some discussion of workarounds)."; std::string warning_message_str = warning_message.str(); RAY_LOG(WARNING) << warning_message_str; - auto error_data_ptr = gcs::CreateErrorTableData("worker_pool_large", - warning_message_str, get_time_()); + auto error_data_ptr = gcs::CreateErrorTableData( + "worker_pool_large", warning_message_str, get_time_()); RAY_CHECK_OK(gcs_client_->Errors().AsyncReportJobError(error_data_ptr, nullptr)); } } @@ -1507,14 +1577,18 @@ WorkerPool::IOWorkerState &WorkerPool::GetIOWorkerStateFromWorkerType( } void WorkerPool::CreateRuntimeEnv( - const std::string &serialized_runtime_env, const JobID &job_id, + const std::string &serialized_runtime_env, + const JobID &job_id, const CreateRuntimeEnvCallback &callback, const std::string &serialized_allocated_resource_instances) { // create runtime env. agent_manager_->CreateRuntimeEnv( - job_id, serialized_runtime_env, serialized_allocated_resource_instances, + job_id, + serialized_runtime_env, + serialized_allocated_resource_instances, [job_id, serialized_runtime_env = std::move(serialized_runtime_env), callback]( - bool successful, const std::string &serialized_runtime_env_context, + bool successful, + const std::string &serialized_runtime_env_context, const std::string &setup_error_message) { if (successful) { callback(true, serialized_runtime_env_context, ""); diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 9ee115fc1..eb06b6c14 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -69,9 +69,10 @@ enum PopWorkerStatus { /// RuntimeEnvCreationFailed. /// \return true if the worker was used. Otherwise, return false /// and the worker will be returned to the worker pool. -using PopWorkerCallback = std::function worker, PopWorkerStatus status, - const std::string &runtime_env_setup_error_message)>; +using PopWorkerCallback = + std::function worker, + PopWorkerStatus status, + const std::string &runtime_env_setup_error_message)>; /// \class WorkerPoolInterface /// @@ -101,7 +102,8 @@ class WorkerPoolInterface { /// resource value will be {"CPU":20000}. /// \return Void. virtual void PopWorker( - const TaskSpecification &task_spec, const PopWorkerCallback &callback, + const TaskSpecification &task_spec, + const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json = "{}") = 0; /// Add an idle worker to the pool. /// @@ -179,16 +181,21 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// \param ray_debugger_external Ray debugger in workers will be started in a way /// that they are accessible from outside the node. /// \param get_time A callback to get the current time. 
- WorkerPool(instrumented_io_context &io_service, const NodeID node_id, - const std::string node_address, int num_workers_soft_limit, + WorkerPool(instrumented_io_context &io_service, + const NodeID node_id, + const std::string node_address, + int num_workers_soft_limit, int num_initial_python_workers_for_first_job, - int maximum_startup_concurrency, int min_worker_port, int max_worker_port, + int maximum_startup_concurrency, + int min_worker_port, + int max_worker_port, const std::vector &worker_ports, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, const std::string &native_library_path, std::function starting_worker_timeout_callback, - int ray_debugger_external, const std::function get_time); + int ray_debugger_external, + const std::function get_time); /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); @@ -230,7 +237,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// finished/failed. /// Returns 0 if the worker should bind on a random port. /// \return If the registration is successful. - Status RegisterWorker(const std::shared_ptr &worker, pid_t pid, + Status RegisterWorker(const std::shared_ptr &worker, + pid_t pid, StartupToken worker_startup_token, std::function send_reply_callback); @@ -327,7 +335,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { void PushWorker(const std::shared_ptr &worker); /// See interface. - void PopWorker(const TaskSpecification &task_spec, const PopWorkerCallback &callback, + void PopWorker(const TaskSpecification &task_spec, + const PopWorkerCallback &callback, const std::string &allocated_instances_serialized_json = "{}"); /// Try to prestart a number of workers suitable the given task spec. Prestarting @@ -338,7 +347,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// \param backlog_size The number of tasks in the client backlog of this shape. /// \param num_available_cpus The number of CPUs that are currently unused. /// We aim to prestart 1 worker per CPU, up to the the backlog size. - void PrestartWorkers(const TaskSpecification &task_spec, int64_t backlog_size, + void PrestartWorkers(const TaskSpecification &task_spec, + int64_t backlog_size, int64_t num_available_cpus); /// Return the current size of the worker pool for the requested language. Counts only @@ -416,7 +426,9 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// \return The process that we started and a token. If the token is less than 0, /// we didn't start a process. std::tuple StartWorkerProcess( - const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, + const Language &language, + const rpc::WorkerType worker_type, + const JobID &job_id, PopWorkerStatus *status /*output*/, const std::vector &dynamic_options = {}, const int runtime_env_hash = 0, @@ -545,7 +557,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// (due to worker process crash or any other reasons), remove them /// from `starting_worker_processes`. Otherwise if we'll mistakenly /// think there are unregistered workers, and won't start new workers. 
- void MonitorStartingWorkerProcess(const Process &proc, StartupToken proc_startup_token, + void MonitorStartingWorkerProcess(const Process &proc, + StartupToken proc_startup_token, const Language &language, const rpc::WorkerType worker_type); @@ -621,19 +634,26 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// \param task_id The related task id. void InvokePopWorkerCallbackForProcess( absl::flat_hash_map &workers_to_tasks, - StartupToken startup_token, const std::shared_ptr &worker, - const PopWorkerStatus &status, bool *found /* output */, - bool *worker_used /* output */, TaskID *task_id /* output */); + StartupToken startup_token, + const std::shared_ptr &worker, + const PopWorkerStatus &status, + bool *found /* output */, + bool *worker_used /* output */, + TaskID *task_id /* output */); /// Create runtime env asynchronously by runtime env agent. void CreateRuntimeEnv( - const std::string &serialized_runtime_env, const JobID &job_id, + const std::string &serialized_runtime_env, + const JobID &job_id, const CreateRuntimeEnvCallback &callback, const std::string &serialized_allocated_resource_instances = "{}"); void AddStartingWorkerProcess( - State &state, const int workers_to_start, const rpc::WorkerType worker_type, - const Process &proc, const std::chrono::high_resolution_clock::time_point &start, + State &state, + const int workers_to_start, + const rpc::WorkerType worker_type, + const Process &proc, + const std::chrono::high_resolution_clock::time_point &start, const rpc::RuntimeEnvInfo &runtime_env_info); void RemoveStartingWorkerProcess(State &state, const StartupToken &proc_startup_token); diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 14a06dcdc..c379fa38a 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -121,9 +121,21 @@ class WorkerPoolMock : public WorkerPool { absl::flat_hash_map> &mock_worker_rpc_clients) : WorkerPool( - io_service, NodeID::FromRandom(), "", POOL_SIZE_SOFT_LIMIT, 0, - MAXIMUM_STARTUP_CONCURRENCY, 0, 0, {}, nullptr, worker_commands, "", []() {}, - 0, [this]() { return current_time_ms_; }), + io_service, + NodeID::FromRandom(), + "", + POOL_SIZE_SOFT_LIMIT, + 0, + MAXIMUM_STARTUP_CONCURRENCY, + 0, + 0, + {}, + nullptr, + worker_commands, + "", + []() {}, + 0, + [this]() { return current_time_ms_; }), last_worker_process_(), instrumented_io_service_(io_service), error_message_type_(1), @@ -216,27 +228,38 @@ class WorkerPoolMock : public WorkerPool { } std::shared_ptr CreateWorker( - Process proc, const Language &language = Language::PYTHON, + Process proc, + const Language &language = Language::PYTHON, const JobID &job_id = JOB_ID, const rpc::WorkerType worker_type = rpc::WorkerType::WORKER, - int runtime_env_hash = 0, StartupToken worker_startup_token = 0, + int runtime_env_hash = 0, + StartupToken worker_startup_token = 0, bool set_process = true) { std::function client_handler = [this](ClientConnection &client) { HandleNewClient(client); }; - std::function, int64_t, - const std::vector &)> + std::function, int64_t, const std::vector &)> message_handler = [this](std::shared_ptr client, int64_t message_type, const std::vector &message) { HandleMessage(client, message_type, message); }; local_stream_socket socket(instrumented_io_service_); - auto client = - ClientConnection::Create(client_handler, message_handler, std::move(socket), - "worker", {}, error_message_type_); - std::shared_ptr worker_ = std::make_shared( - job_id, 
runtime_env_hash, WorkerID::FromRandom(), language, worker_type, - "127.0.0.1", client, client_call_manager_, worker_startup_token); + auto client = ClientConnection::Create(client_handler, + message_handler, + std::move(socket), + "worker", + {}, + error_message_type_); + std::shared_ptr worker_ = std::make_shared(job_id, + runtime_env_hash, + WorkerID::FromRandom(), + language, + worker_type, + "127.0.0.1", + client, + client_call_manager_, + worker_startup_token); std::shared_ptr worker = std::dynamic_pointer_cast(worker_); auto rpc_client = std::make_shared(instrumented_io_service_); @@ -296,12 +319,16 @@ class WorkerPoolMock : public WorkerPool { auto register_workers = num_workers - timeout_worker_number; for (int i = 0; i < register_workers; i++) { auto worker = CreateWorker( - it->first, is_java ? Language::JAVA : Language::PYTHON, JOB_ID, - rpc::WorkerType::WORKER, runtime_env_hash, + it->first, + is_java ? Language::JAVA : Language::PYTHON, + JOB_ID, + rpc::WorkerType::WORKER, + runtime_env_hash, startup_tokens_by_proc_[it->first], // Don't set process to ensure the `RegisterWorker` succeeds below. false); - RAY_CHECK_OK(RegisterWorker(worker, it->first.GetId(), + RAY_CHECK_OK(RegisterWorker(worker, + it->first.GetId(), startup_tokens_by_proc_[it->first], [](Status, int) {})); OnWorkerStarted(worker); @@ -317,26 +344,28 @@ class WorkerPoolMock : public WorkerPool { // \param[in] push_workers If true, tries to push the workers from the started // processes. std::shared_ptr PopWorkerSync( - const TaskSpecification &task_spec, bool push_workers = true, - PopWorkerStatus *worker_status = nullptr, int timeout_worker_number = 0, + const TaskSpecification &task_spec, + bool push_workers = true, + PopWorkerStatus *worker_status = nullptr, + int timeout_worker_number = 0, std::string *runtime_env_error_msg = nullptr) { std::shared_ptr popped_worker = nullptr; std::promise promise; - this->PopWorker( - task_spec, - [&popped_worker, worker_status, &promise, runtime_env_error_msg]( - const std::shared_ptr worker, PopWorkerStatus status, - const std::string &runtime_env_setup_error_message) -> bool { - popped_worker = worker; - if (worker_status != nullptr) { - *worker_status = status; - } - if (runtime_env_error_msg) { - *runtime_env_error_msg = runtime_env_setup_error_message; - } - promise.set_value(true); - return true; - }); + this->PopWorker(task_spec, + [&popped_worker, worker_status, &promise, runtime_env_error_msg]( + const std::shared_ptr worker, + PopWorkerStatus status, + const std::string &runtime_env_setup_error_message) -> bool { + popped_worker = worker; + if (worker_status != nullptr) { + *worker_status = status; + } + if (runtime_env_error_msg) { + *runtime_env_error_msg = runtime_env_setup_error_message; + } + promise.set_value(true); + return true; + }); if (push_workers) { PushWorkers(timeout_worker_number); } @@ -357,7 +386,8 @@ class WorkerPoolMock : public WorkerPool { absl::flat_hash_map> &mock_worker_rpc_clients_; void HandleNewClient(ClientConnection &){}; - void HandleMessage(std::shared_ptr, int64_t, + void HandleMessage(std::shared_ptr, + int64_t, const std::vector &){}; }; @@ -396,17 +426,18 @@ class WorkerPoolTest : public ::testing::Test { } std::shared_ptr CreateSpillWorker(Process proc) { - return worker_pool_->CreateWorker(proc, Language::PYTHON, JobID::Nil(), - rpc::WorkerType::SPILL_WORKER); + return worker_pool_->CreateWorker( + proc, Language::PYTHON, JobID::Nil(), rpc::WorkerType::SPILL_WORKER); } std::shared_ptr CreateRestoreWorker(Process proc) { - 
return worker_pool_->CreateWorker(proc, Language::PYTHON, JobID::Nil(), - rpc::WorkerType::RESTORE_WORKER); + return worker_pool_->CreateWorker( + proc, Language::PYTHON, JobID::Nil(), rpc::WorkerType::RESTORE_WORKER); } std::shared_ptr RegisterDriver( - const Language &language = Language::PYTHON, const JobID &job_id = JOB_ID, + const Language &language = Language::PYTHON, + const JobID &job_id = JOB_ID, const rpc::JobConfig &job_config = rpc::JobConfig()) { auto driver = worker_pool_->CreateWorker(Process::CreateNewDummy(), Language::PYTHON, job_id); @@ -416,8 +447,8 @@ class WorkerPoolTest : public ::testing::Test { } void SetWorkerCommands(const WorkerCommandMap &worker_commands) { - worker_pool_ = std::make_unique(io_service_, worker_commands, - mock_worker_rpc_clients_); + worker_pool_ = std::make_unique( + io_service_, worker_commands, mock_worker_rpc_clients_); rpc::JobConfig job_config; job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); RegisterDriver(Language::PYTHON, JOB_ID, job_config); @@ -432,8 +463,8 @@ class WorkerPoolTest : public ::testing::Test { Process last_started_worker_process; for (int i = 0; i < desired_initial_worker_process_count; i++) { PopWorkerStatus status; - worker_pool_->StartWorkerProcess(language, rpc::WorkerType::WORKER, JOB_ID, - &status); + worker_pool_->StartWorkerProcess( + language, rpc::WorkerType::WORKER, JOB_ID, &status); ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <= expected_worker_process_count); Process prev = worker_pool_->LastStartedWorkerProcess(); @@ -443,7 +474,8 @@ class WorkerPoolTest : public ::testing::Test { worker_pool_->GetWorkerCommand(last_started_worker_process); if (language == Language::JAVA) { auto it = std::find( - real_command.begin(), real_command.end(), + real_command.begin(), + real_command.end(), GetNumJavaWorkersPerProcessSystemProperty(num_workers_per_process)); ASSERT_NE(it, real_command.end()); } @@ -475,8 +507,8 @@ class WorkerPoolTest : public ::testing::Test { false); const rpc::RegisterAgentRequest request; rpc::RegisterAgentReply reply; - auto send_reply_callback = [](ray::Status status, std::function f1, - std::function f2) {}; + auto send_reply_callback = + [](ray::Status status, std::function f1, std::function f2) {}; agent_manager->HandleRegisterAgent(request, &reply, send_reply_callback); worker_pool_->SetAgentManager(agent_manager); } @@ -515,8 +547,10 @@ static inline rpc::RuntimeEnvInfo ExampleRuntimeEnvInfoFromString( } static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, - const JobID &job_id = JOB_ID, const ActorID actor_creation_id = ActorID::Nil(), + const ActorID actor_id = ActorID::Nil(), + const Language &language = Language::PYTHON, + const JobID &job_id = JOB_ID, + const ActorID actor_creation_id = ActorID::Nil(), const std::vector &dynamic_worker_options = {}, const TaskID &task_id = TaskID::FromRandom(JobID::Nil()), const rpc::RuntimeEnvInfo runtime_env_info = rpc::RuntimeEnvInfo()) { @@ -678,8 +712,8 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { {"-Dmy-actor.hello=foo", "-Dmy-actor.world=bar", "-Xmx2g", "-Xms1g"}); auto task_id = TaskID::ForDriverTask(JOB_ID); auto actor_id = ActorID::Of(JOB_ID, task_id, 1); - TaskSpecification task_spec = ExampleTaskSpec(ActorID::Nil(), Language::JAVA, JOB_ID, - actor_id, actor_jvm_options, task_id); + TaskSpecification task_spec = ExampleTaskSpec( + ActorID::Nil(), Language::JAVA, JOB_ID, actor_id, actor_jvm_options, 
task_id); rpc::JobConfig job_config = rpc::JobConfig(); job_config.add_code_search_path("/test/code_search_path"); @@ -710,8 +744,8 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { expected_command.push_back("-Dray.raylet.startup-token=0"); expected_command.push_back("-Dray.internal.runtime-env-hash=1"); // User-defined per-process options - expected_command.insert(expected_command.end(), actor_jvm_options.begin(), - actor_jvm_options.end()); + expected_command.insert( + expected_command.end(), actor_jvm_options.begin(), actor_jvm_options.end()); // Entry point expected_command.push_back("MainClass"); ASSERT_EQ(real_command, expected_command); @@ -734,13 +768,15 @@ TEST_F(WorkerPoolTest, PopWorkerMultiTenancy) { // Make the first worker an actor worker. if (i == 0) { auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); - auto task_spec = ExampleTaskSpec(/*actor_id=*/ActorID::Nil(), Language::PYTHON, - job_id, actor_creation_id); + auto task_spec = ExampleTaskSpec( + /*actor_id=*/ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id); runtime_env_hash = task_spec.GetRuntimeEnvHash(); } - auto worker = - worker_pool_->CreateWorker(Process::CreateNewDummy(), Language::PYTHON, job_id, - rpc::WorkerType::WORKER, runtime_env_hash); + auto worker = worker_pool_->CreateWorker(Process::CreateNewDummy(), + Language::PYTHON, + job_id, + rpc::WorkerType::WORKER, + runtime_env_hash); worker_pool_->PushWorker(worker); } } @@ -752,8 +788,8 @@ TEST_F(WorkerPoolTest, PopWorkerMultiTenancy) { for (auto job_id : job_ids) { auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); // Pop workers for actor creation tasks. - auto task_spec = ExampleTaskSpec(/*actor_id=*/ActorID::Nil(), Language::PYTHON, - job_id, actor_creation_id); + auto task_spec = ExampleTaskSpec( + /*actor_id=*/ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id); auto worker = worker_pool_->PopWorkerSync(task_spec); ASSERT_TRUE(worker); ASSERT_EQ(worker->GetAssignedJobId(), job_id); @@ -791,7 +827,8 @@ TEST_F(WorkerPoolTest, MaximumStartupConcurrency) { for (int i = 0; i < MAXIMUM_STARTUP_CONCURRENCY; i++) { worker_pool_->PopWorker( task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, + [](const std::shared_ptr worker, + PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { return true; }); auto last_process = worker_pool_->LastStartedWorkerProcess(); RAY_CHECK(last_process.IsValid()); @@ -802,7 +839,8 @@ TEST_F(WorkerPoolTest, MaximumStartupConcurrency) { ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting()); worker_pool_->PopWorker( task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, + [](const std::shared_ptr worker, + PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { return true; }); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting()); @@ -811,9 +849,9 @@ TEST_F(WorkerPoolTest, MaximumStartupConcurrency) { for (const auto &process : started_processes) { auto worker = worker_pool_->CreateWorker(Process()); worker->SetStartupToken(worker_pool_->GetStartupToken(process)); - RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, process.GetId(), - worker_pool_->GetStartupToken(process), - [](Status, int) {})); + RAY_CHECK_OK(worker_pool_->RegisterWorker( + worker, process.GetId(), worker_pool_->GetStartupToken(process), [](Status, int) { + })); // Calling `RegisterWorker` won't affect the counter of 
starting worker processes. ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting()); workers.push_back(worker); @@ -823,7 +861,8 @@ TEST_F(WorkerPoolTest, MaximumStartupConcurrency) { ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting()); worker_pool_->PopWorker( task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, + [](const std::shared_ptr worker, + PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { return true; }); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting()); @@ -1326,9 +1365,9 @@ TEST_F(WorkerPoolTest, TestWorkerCappingWithExitDelay) { auto worker = worker_pool_->CreateWorker(Process(), language); worker->SetStartupToken(worker_pool_->GetStartupToken(proc)); workers.push_back(worker); - RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), - worker_pool_->GetStartupToken(proc), - [](Status, int) {})); + RAY_CHECK_OK(worker_pool_->RegisterWorker( + worker, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) { + })); worker_pool_->OnWorkerStarted(worker); ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), worker); worker_pool_->PushWorker(worker); @@ -1388,12 +1427,20 @@ TEST_F(WorkerPoolTest, TestWorkerCappingWithExitDelay) { TEST_F(WorkerPoolTest, PopWorkerWithRuntimeEnv) { ASSERT_EQ(worker_pool_->GetProcessSize(), 0); auto actor_creation_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1); - const auto actor_creation_task_spec = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id, {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"})); - const auto normal_task_spec = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"})); + const auto actor_creation_task_spec = ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + JOB_ID, + actor_creation_id, + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfo({"XXX"})); + const auto normal_task_spec = ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + JOB_ID, + ActorID::Nil(), + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfo({"XXX"})); const auto normal_task_spec_without_runtime_env = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), {}); // Pop worker for actor creation task again. @@ -1465,14 +1512,23 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) { // Start actor with runtime env. auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id, - {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info); + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + job_id, + actor_creation_id, + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + runtime_env_info); auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); ASSERT_EQ(valid_uris.size(), 1); // Start task with runtime env. 
- const auto normal_task_spec = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), - {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info); + const auto normal_task_spec = ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + job_id, + ActorID::Nil(), + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + runtime_env_info); auto popped_normal_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); ASSERT_EQ(valid_uris.size(), 1); // Disconnect actor worker. @@ -1499,8 +1555,13 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) { // Start actor with runtime env. auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 2); const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id, - {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info); + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + job_id, + actor_creation_id, + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + runtime_env_info); auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); ASSERT_EQ(valid_uris.size(), 1); // Start task with runtime env. @@ -1536,8 +1597,13 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) { for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1); const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), Language::JAVA, job_id, actor_creation_id, {}, - TaskID::FromRandom(JobID::Nil()), runtime_env_info); + ExampleTaskSpec(ActorID::Nil(), + Language::JAVA, + job_id, + actor_creation_id, + {}, + TaskID::FromRandom(JobID::Nil()), + runtime_env_info); auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); ASSERT_NE(popped_actor_worker, nullptr); workers.push_back(popped_actor_worker); @@ -1558,8 +1624,13 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) { // process. auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), Language::JAVA, job_id, actor_creation_id, {}, - TaskID::FromRandom(JobID::Nil()), runtime_env_info); + ExampleTaskSpec(ActorID::Nil(), + Language::JAVA, + job_id, + actor_creation_id, + {}, + TaskID::FromRandom(JobID::Nil()), + runtime_env_info); PopWorkerStatus status; // Only one worker registration. All the other worker registration times out. 
auto popped_actor_worker = worker_pool_->PopWorkerSync( @@ -1588,24 +1659,38 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { ASSERT_EQ(worker_pool_->GetProcessSize(), 0); auto actor_creation_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1); const auto actor_creation_task_spec_1 = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id, - /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + JOB_ID, + actor_creation_id, + /*dynamic_options=*/{}, + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfoFromString("mock_runtime_env_1")); const auto task_spec_1 = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), - /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + JOB_ID, + ActorID::Nil(), + /*dynamic_options=*/{}, + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfoFromString("mock_runtime_env_1")); const auto task_spec_2 = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), - /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + JOB_ID, + ActorID::Nil(), + /*dynamic_options=*/{}, + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfoFromString("mock_runtime_env_2")); const int runtime_env_hash_1 = actor_creation_task_spec_1.GetRuntimeEnvHash(); // Push worker with runtime env 1. - auto worker = - worker_pool_->CreateWorker(Process::CreateNewDummy(), Language::PYTHON, JOB_ID, - rpc::WorkerType::WORKER, runtime_env_hash_1); + auto worker = worker_pool_->CreateWorker(Process::CreateNewDummy(), + Language::PYTHON, + JOB_ID, + rpc::WorkerType::WORKER, + runtime_env_hash_1); worker_pool_->PushWorker(worker); // Try to pop worker for task with runtime env 2. @@ -1624,8 +1709,11 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { } // Push another worker with runtime env 1. - worker = worker_pool_->CreateWorker(Process::CreateNewDummy(), Language::PYTHON, JOB_ID, - rpc::WorkerType::WORKER, runtime_env_hash_1); + worker = worker_pool_->CreateWorker(Process::CreateNewDummy(), + Language::PYTHON, + JOB_ID, + rpc::WorkerType::WORKER, + runtime_env_hash_1); worker_pool_->PushWorker(worker); // Try to pop the worker for an actor with runtime env 1. @@ -1640,13 +1728,13 @@ TEST_F(WorkerPoolTest, WorkerNoLeaks) { const auto task_spec = ExampleTaskSpec(); // Pop a worker and don't dispatch. - worker_pool_->PopWorker( - task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, - const std::string &runtime_env_setup_error_message) -> bool { - // Don't dispatch this worker. - return false; - }); + worker_pool_->PopWorker(task_spec, + [](const std::shared_ptr worker, + PopWorkerStatus status, + const std::string &runtime_env_setup_error_message) -> bool { + // Don't dispatch this worker. + return false; + }); // One worker process has been started. ASSERT_EQ(worker_pool_->GetProcessSize(), 1); // No idle workers because no workers pushed. @@ -1656,24 +1744,24 @@ TEST_F(WorkerPoolTest, WorkerNoLeaks) { // The worker has been pushed but not dispatched. ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1); // Pop a worker and don't dispatch. - worker_pool_->PopWorker( - task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, - const std::string &runtime_env_setup_error_message) -> bool { - // Don't dispatch this worker. 
- return false; - }); + worker_pool_->PopWorker(task_spec, + [](const std::shared_ptr worker, + PopWorkerStatus status, + const std::string &runtime_env_setup_error_message) -> bool { + // Don't dispatch this worker. + return false; + }); // The worker is popped but not dispatched. ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1); ASSERT_EQ(worker_pool_->GetProcessSize(), 1); // Pop a worker and dispatch. - worker_pool_->PopWorker( - task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, - const std::string &runtime_env_setup_error_message) -> bool { - // Dispatch this worker. - return true; - }); + worker_pool_->PopWorker(task_spec, + [](const std::shared_ptr worker, + PopWorkerStatus status, + const std::string &runtime_env_setup_error_message) -> bool { + // Dispatch this worker. + return true; + }); // The worker is popped and dispatched. ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 0); ASSERT_EQ(worker_pool_->GetProcessSize(), 1); @@ -1690,7 +1778,8 @@ TEST_F(WorkerPoolTest, PopWorkerStatus) { auto task_spec = ExampleTaskSpec(); worker_pool_->PopWorker( task_spec, - [](const std::shared_ptr worker, PopWorkerStatus status, + [](const std::shared_ptr worker, + PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { return true; }); } ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting()); @@ -1726,21 +1815,31 @@ TEST_F(WorkerPoolTest, PopWorkerStatus) { /* Test PopWorkerStatus RuntimeEnvCreationFailed */ // Create a task with bad runtime env. - const auto task_spec_with_bad_runtime_env = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfoFromString(BAD_RUNTIME_ENV)); + const auto task_spec_with_bad_runtime_env = + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + job_id, + ActorID::Nil(), + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfoFromString(BAD_RUNTIME_ENV)); std::string error_msg; - popped_worker = worker_pool_->PopWorkerSync(task_spec_with_bad_runtime_env, true, - &status, 0, &error_msg); + popped_worker = worker_pool_->PopWorkerSync( + task_spec_with_bad_runtime_env, true, &status, 0, &error_msg); // PopWorker failed and the status is `RuntimeEnvCreationFailed`. ASSERT_EQ(popped_worker, nullptr); ASSERT_EQ(status, PopWorkerStatus::RuntimeEnvCreationFailed); ASSERT_EQ(error_msg, BAD_RUNTIME_ENV_ERROR_MSG); // Create a task with available runtime env. - const auto task_spec_with_runtime_env = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"})); + const auto task_spec_with_runtime_env = + ExampleTaskSpec(ActorID::Nil(), + Language::PYTHON, + job_id, + ActorID::Nil(), + {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfo({"XXX"})); popped_worker = worker_pool_->PopWorkerSync(task_spec_with_runtime_env, true, &status); // PopWorker success. 
ASSERT_NE(popped_worker, nullptr); diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index b932735ba..b075c629f 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -27,9 +27,11 @@ using MessageType = ray::protocol::MessageType; namespace { inline flatbuffers::Offset to_flatbuf( flatbuffers::FlatBufferBuilder &fbb, const ray::rpc::Address &address) { - return ray::protocol::CreateAddress( - fbb, fbb.CreateString(address.raylet_id()), fbb.CreateString(address.ip_address()), - address.port(), fbb.CreateString(address.worker_id())); + return ray::protocol::CreateAddress(fbb, + fbb.CreateString(address.raylet_id()), + fbb.CreateString(address.ip_address()), + address.port(), + fbb.CreateString(address.worker_id())); } flatbuffers::Offset>> @@ -49,7 +51,8 @@ namespace ray { raylet::RayletConnection::RayletConnection(instrumented_io_context &io_service, const std::string &raylet_socket, - int num_retries, int64_t timeout) { + int num_retries, + int64_t timeout) { local_stream_socket socket(io_service); Status s = ConnectSocketRetry(socket, raylet_socket, num_retries, timeout); // If we could not connect to the socket, exit. @@ -97,21 +100,35 @@ raylet::RayletClient::RayletClient( raylet::RayletClient::RayletClient( instrumented_io_context &io_service, std::shared_ptr grpc_client, - const std::string &raylet_socket, const WorkerID &worker_id, - rpc::WorkerType worker_type, const JobID &job_id, const int &runtime_env_hash, - const Language &language, const std::string &ip_address, Status *status, - NodeID *raylet_id, int *port, std::string *serialized_job_config, + const std::string &raylet_socket, + const WorkerID &worker_id, + rpc::WorkerType worker_type, + const JobID &job_id, + const int &runtime_env_hash, + const Language &language, + const std::string &ip_address, + Status *status, + NodeID *raylet_id, + int *port, + std::string *serialized_job_config, StartupToken startup_token) : grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) { conn_ = std::make_unique(io_service, raylet_socket, -1, -1); flatbuffers::FlatBufferBuilder fbb; // TODO(suquark): Use `WorkerType` in `common.proto` without converting to int. - auto message = protocol::CreateRegisterClientRequest( - fbb, static_cast(worker_type), to_flatbuf(fbb, worker_id), getpid(), - startup_token, to_flatbuf(fbb, job_id), runtime_env_hash, language, - fbb.CreateString(ip_address), - /*port=*/0, fbb.CreateString(*serialized_job_config)); + auto message = + protocol::CreateRegisterClientRequest(fbb, + static_cast(worker_type), + to_flatbuf(fbb, worker_id), + getpid(), + startup_token, + to_flatbuf(fbb, job_id), + runtime_env_hash, + language, + fbb.CreateString(ip_address), + /*port=*/0, + fbb.CreateString(*serialized_job_config)); fbb.Finish(message); // Register the process ID with the raylet. // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. 
@@ -186,14 +203,20 @@ Status raylet::RayletClient::TaskDone() { Status raylet::RayletClient::FetchOrReconstruct( const std::vector<ObjectID> &object_ids, - const std::vector<rpc::Address> &owner_addresses, bool fetch_only, - bool mark_worker_blocked, const TaskID &current_task_id) { + const std::vector<rpc::Address> &owner_addresses, + bool fetch_only, + bool mark_worker_blocked, + const TaskID &current_task_id) { RAY_CHECK(object_ids.size() == owner_addresses.size()); flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = protocol::CreateFetchOrReconstruct( - fbb, object_ids_message, AddressesToFlatbuffer(fbb, owner_addresses), fetch_only, - mark_worker_blocked, to_flatbuf(fbb, current_task_id)); + auto message = + protocol::CreateFetchOrReconstruct(fbb, + object_ids_message, + AddressesToFlatbuffer(fbb, owner_addresses), + fetch_only, + mark_worker_blocked, + to_flatbuf(fbb, current_task_id)); fbb.Finish(message); return conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); } @@ -221,19 +244,24 @@ Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() { Status raylet::RayletClient::Wait(const std::vector<ObjectID> &object_ids, const std::vector<rpc::Address> &owner_addresses, - int num_returns, int64_t timeout_milliseconds, - bool mark_worker_blocked, const TaskID &current_task_id, + int num_returns, + int64_t timeout_milliseconds, + bool mark_worker_blocked, + const TaskID &current_task_id, WaitResultPair *result) { // Write request. flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), AddressesToFlatbuffer(fbb, owner_addresses), - num_returns, timeout_milliseconds, mark_worker_blocked, - to_flatbuf(fbb, current_task_id)); + auto message = protocol::CreateWaitRequest(fbb, + to_flatbuf(fbb, object_ids), + AddressesToFlatbuffer(fbb, owner_addresses), + num_returns, + timeout_milliseconds, + mark_worker_blocked, + to_flatbuf(fbb, current_task_id)); fbb.Finish(message); std::vector<uint8_t> reply; - RAY_RETURN_NOT_OK(conn_->AtomicRequestReply(MessageType::WaitRequest, - MessageType::WaitReply, &reply, &fbb)); + RAY_RETURN_NOT_OK(conn_->AtomicRequestReply( + MessageType::WaitRequest, MessageType::WaitReply, &reply, &fbb)); // Parse the flatbuffer object.
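The two layouts visible in these hunks are what clang-format produces once argument bin-packing is disabled (presumably BinPackParameters/BinPackArguments set to false in the .clang-format change): if one break after the opening parenthesis lets every argument fit on a single line, they stay together, as in AtomicRequestReply above; otherwise each argument gets its own line, as in CreateWaitRequest. A toy, compilable illustration with hypothetical names:

#include <string>

// Toy functions, hypothetical names, just to show the two layouts.
std::string Join2(const std::string &a, const std::string &b) { return a + b; }
std::string JoinManyDescriptivePieces(const std::string &first_long_piece,
                                      const std::string &second_long_piece,
                                      const std::string &third_long_piece,
                                      const std::string &fourth_long_piece) {
  return first_long_piece + second_long_piece + third_long_piece + fourth_long_piece;
}

int main() {
  // Case 1: breaking once after the '(' lets all arguments fit on one line.
  std::string a = Join2(
      "MessageType::WaitRequest stand-in, ", "MessageType::WaitReply stand-in");
  // Case 2: they still do not fit, so every argument gets its own line.
  std::string b = JoinManyDescriptivePieces("a fairly long first argument, ",
                                            "a fairly long second argument, ",
                                            "a fairly long third argument, ",
                                            "a fairly long fourth argument");
  return (a.size() + b.size()) > 0 ? 0 : 1;
}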
auto reply_message = flatbuffers::GetRoot(reply.data()); auto found = reply_message->found(); @@ -264,13 +292,16 @@ Status raylet::RayletClient::WaitForDirectActorCallArgs( return conn_->WriteMessage(MessageType::WaitForDirectActorCallArgsRequest, &fbb); } -Status raylet::RayletClient::PushError(const JobID &job_id, const std::string &type, +Status raylet::RayletClient::PushError(const JobID &job_id, + const std::string &type, const std::string &error_message, double timestamp) { flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), - fbb.CreateString(error_message), timestamp); + auto message = protocol::CreatePushErrorRequest(fbb, + to_flatbuf(fbb, job_id), + fbb.CreateString(type), + fbb.CreateString(error_message), + timestamp); fbb.Finish(message); return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); } @@ -285,9 +316,11 @@ Status raylet::RayletClient::FreeObjects(const std::vector &object_ids } void raylet::RayletClient::RequestWorkerLease( - const rpc::TaskSpec &task_spec, bool grant_or_reject, + const rpc::TaskSpec &task_spec, + bool grant_or_reject, const rpc::ClientCallback &callback, - const int64_t backlog_size, const bool is_selected_based_on_locality) { + const int64_t backlog_size, + const bool is_selected_based_on_locality) { google::protobuf::Arena arena; auto request = google::protobuf::Arena::CreateMessage(&arena); @@ -326,8 +359,10 @@ void raylet::RayletClient::ReportWorkerBacklog( }); } -Status raylet::RayletClient::ReturnWorker(int worker_port, const WorkerID &worker_id, - bool disconnect_worker, bool worker_exiting) { +Status raylet::RayletClient::ReturnWorker(int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, + bool worker_exiting) { rpc::ReturnWorkerRequest request; request.set_worker_port(worker_port); request.set_worker_id(worker_id.Binary()); @@ -425,7 +460,8 @@ void raylet::RayletClient::ReleaseUnusedBundles( } void raylet::RayletClient::PinObjectIDs( - const rpc::Address &caller_address, const std::vector &object_ids, + const rpc::Address &caller_address, + const std::vector &object_ids, const rpc::ClientCallback &callback) { rpc::PinObjectIDsRequest request; request.mutable_owner_address()->CopyFrom(caller_address); @@ -442,7 +478,8 @@ void raylet::RayletClient::PinObjectIDs( } void raylet::RayletClient::ShutdownRaylet( - const NodeID &node_id, bool graceful, + const NodeID &node_id, + bool graceful, const rpc::ClientCallback &callback) { rpc::ShutdownRayletRequest request; request.set_graceful(graceful); @@ -473,8 +510,8 @@ void raylet::RayletClient::RequestResourceReport( void raylet::RayletClient::SubscribeToPlasma(const ObjectID &object_id, const rpc::Address &owner_address) { flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateSubscribePlasmaReady(fbb, to_flatbuf(fbb, object_id), - to_flatbuf(fbb, owner_address)); + auto message = protocol::CreateSubscribePlasmaReady( + fbb, to_flatbuf(fbb, object_id), to_flatbuf(fbb, owner_address)); fbb.Finish(message); RAY_CHECK_OK(conn_->WriteMessage(MessageType::SubscribePlasmaReady, &fbb)); diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index 66a8a8b47..f286f3861 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -51,7 +51,8 @@ class PinObjectsInterface { public: /// Request to a raylet to pin a plasma object. The callback will be sent via gRPC. 
virtual void PinObjectIDs( - const rpc::Address &caller_address, const std::vector &object_ids, + const rpc::Address &caller_address, + const std::vector &object_ids, const ray::rpc::ClientCallback &callback) = 0; virtual ~PinObjectsInterface(){}; @@ -67,7 +68,8 @@ class WorkerLeaseInterface { /// \param callback: The callback to call when the request finishes. /// \param backlog_size The queue length for the given shape on the CoreWorker. virtual void RequestWorkerLease( - const rpc::TaskSpec &task_spec, bool grant_or_reject, + const rpc::TaskSpec &task_spec, + bool grant_or_reject, const ray::rpc::ClientCallback &callback, const int64_t backlog_size = -1, const bool is_selected_based_on_locality = false) = 0; @@ -78,8 +80,10 @@ class WorkerLeaseInterface { /// \param disconnect_worker Whether the raylet should disconnect the worker. /// \param worker_exiting Whether the worker is exiting and cannot be reused. /// \return ray::Status - virtual ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, - bool disconnect_worker, bool worker_exiting) = 0; + virtual ray::Status ReturnWorker(int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, + bool worker_exiting) = 0; /// Notify raylets to release unused workers. /// \param workers_in_use Workers currently in use. @@ -182,7 +186,8 @@ class RayletClientInterface : public PinObjectsInterface, const rpc::ClientCallback &callback) = 0; virtual void ShutdownRaylet( - const NodeID &node_id, bool graceful, + const NodeID &node_id, + bool graceful, const rpc::ClientCallback &callback) = 0; }; @@ -199,13 +204,16 @@ class RayletConnection { /// \param job_id The ID of the driver. This is non-nil if the client is a /// driver. /// \return The connection information. - RayletConnection(instrumented_io_context &io_service, const std::string &raylet_socket, - int num_retries, int64_t timeout); + RayletConnection(instrumented_io_context &io_service, + const std::string &raylet_socket, + int num_retries, + int64_t timeout); ray::Status WriteMessage(MessageType type, flatbuffers::FlatBufferBuilder *fbb = nullptr); - ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type, + ray::Status AtomicRequestReply(MessageType request_type, + MessageType reply_type, std::vector *reply_message, flatbuffers::FlatBufferBuilder *fbb = nullptr); @@ -246,11 +254,18 @@ class RayletClient : public RayletClientInterface { /// it during startup as a command line argument. RayletClient(instrumented_io_context &io_service, std::shared_ptr grpc_client, - const std::string &raylet_socket, const WorkerID &worker_id, - rpc::WorkerType worker_type, const JobID &job_id, - const int &runtime_env_hash, const Language &language, - const std::string &ip_address, Status *status, NodeID *raylet_id, - int *port, std::string *serialized_job_config, StartupToken startup_token); + const std::string &raylet_socket, + const WorkerID &worker_id, + rpc::WorkerType worker_type, + const JobID &job_id, + const int &runtime_env_hash, + const Language &language, + const std::string &ip_address, + Status *status, + NodeID *raylet_id, + int *port, + std::string *serialized_job_config, + StartupToken startup_token); /// Connect to the raylet via grpc only. /// @@ -287,7 +302,8 @@ class RayletClient : public RayletClientInterface { /// \return int 0 means correct, other numbers mean error. 
ray::Status FetchOrReconstruct(const std::vector &object_ids, const std::vector &owner_addresses, - bool fetch_only, bool mark_worker_blocked, + bool fetch_only, + bool mark_worker_blocked, const TaskID ¤t_task_id); /// Notify the raylet that this client (worker) is no longer blocked. @@ -322,9 +338,12 @@ class RayletClient : public RayletClientInterface { /// found, and the second element the objects that were not found. /// \return ray::Status. ray::Status Wait(const std::vector &object_ids, - const std::vector &owner_addresses, int num_returns, - int64_t timeout_milliseconds, bool mark_worker_blocked, - const TaskID ¤t_task_id, WaitResultPair *result); + const std::vector &owner_addresses, + int num_returns, + int64_t timeout_milliseconds, + bool mark_worker_blocked, + const TaskID ¤t_task_id, + WaitResultPair *result); /// Wait for the given objects, asynchronously. The core worker is notified when /// the wait completes. @@ -342,8 +361,10 @@ class RayletClient : public RayletClientInterface { /// \param The error message. /// \param The timestamp of the error. /// \return ray::Status. - ray::Status PushError(const ray::JobID &job_id, const std::string &type, - const std::string &error_message, double timestamp); + ray::Status PushError(const ray::JobID &job_id, + const std::string &type, + const std::string &error_message, + double timestamp); /// Free a list of objects from object stores. /// @@ -363,13 +384,17 @@ class RayletClient : public RayletClientInterface { /// Implements WorkerLeaseInterface. void RequestWorkerLease( - const rpc::TaskSpec &resource_spec, bool grant_or_reject, + const rpc::TaskSpec &resource_spec, + bool grant_or_reject, const ray::rpc::ClientCallback &callback, - const int64_t backlog_size, const bool is_selected_based_on_locality) override; + const int64_t backlog_size, + const bool is_selected_based_on_locality) override; /// Implements WorkerLeaseInterface. - ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, - bool disconnect_worker, bool worker_exiting) override; + ray::Status ReturnWorker(int worker_port, + const WorkerID &worker_id, + bool disconnect_worker, + bool worker_exiting) override; /// Implements WorkerLeaseInterface. void ReportWorkerBacklog( @@ -409,11 +434,13 @@ class RayletClient : public RayletClientInterface { const rpc::ClientCallback &callback) override; void PinObjectIDs( - const rpc::Address &caller_address, const std::vector &object_ids, + const rpc::Address &caller_address, + const std::vector &object_ids, const ray::rpc::ClientCallback &callback) override; void ShutdownRaylet( - const NodeID &node_id, bool graceful, + const NodeID &node_id, + bool graceful, const rpc::ClientCallback &callback) override; void GetSystemConfig( diff --git a/src/ray/rpc/agent_manager/agent_manager_client.h b/src/ray/rpc/agent_manager/agent_manager_client.h index 8c9e36a42..bbd692265 100644 --- a/src/ray/rpc/agent_manager/agent_manager_client.h +++ b/src/ray/rpc/agent_manager/agent_manager_client.h @@ -29,17 +29,20 @@ class AgentManagerClient { /// \param[in] address Address of the agent manager server. /// \param[in] port Port of the agent manager server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. 
- AgentManagerClient(const std::string &address, const int port, + AgentManagerClient(const std::string &address, + const int port, ClientCallManager &client_call_manager) { - grpc_client_ = std::make_unique<GrpcClient<AgentManagerService>>(address, port, - client_call_manager); + grpc_client_ = std::make_unique<GrpcClient<AgentManagerService>>( + address, port, client_call_manager); }; /// Register agent service to the agent manager server /// /// \param request The request message /// \param callback The callback function that handles reply - VOID_RPC_CLIENT_METHOD(AgentManagerService, RegisterAgent, grpc_client_, + VOID_RPC_CLIENT_METHOD(AgentManagerService, + RegisterAgent, + grpc_client_, /*method_timeout_ms*/ -1, ) private: diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 6043e6f1c..e63a1d585 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -167,7 +167,8 @@ class ClientCallTag { /// \tparam Reply Type of the reply message. template <class GrpcService, class Request, class Reply> using PrepareAsyncFunction = std::unique_ptr<grpc::ClientAsyncResponseReader<Reply>> ( - GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request, + GrpcService::Stub::*)(grpc::ClientContext *context, + const Request &request, grpc::CompletionQueue *cq); /// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of @@ -183,7 +184,8 @@ class ClientCallManager { /// /// \param[in] main_service The main event loop, to which the callback functions will be /// posted. - explicit ClientCallManager(instrumented_io_context &main_service, int num_threads = 1, + explicit ClientCallManager(instrumented_io_context &main_service, + int num_threads = 1, int64_t call_timeout_ms = -1) : main_service_(main_service), num_threads_(num_threads), @@ -194,8 +196,8 @@ class ClientCallManager { cqs_.reserve(num_threads_); for (int i = 0; i < num_threads_; i++) { cqs_.push_back(std::make_unique<grpc::CompletionQueue>()); - polling_threads_.emplace_back(&ClientCallManager::PollEventsFromCompletionQueue, - this, i); + polling_threads_.emplace_back( + &ClientCallManager::PollEventsFromCompletionQueue, this, i); } } @@ -229,14 +231,16 @@ class ClientCallManager { std::shared_ptr<ClientCall> CreateCall( typename GrpcService::Stub &stub, const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function, - const Request &request, const ClientCallback<Reply> &callback, - std::string call_name, int64_t method_timeout_ms = -1) { + const Request &request, + const ClientCallback<Reply> &callback, + std::string call_name, + int64_t method_timeout_ms = -1) { auto stats_handle = main_service_.stats().RecordStart(call_name); if (method_timeout_ms == -1) { method_timeout_ms = call_timeout_ms_; } - auto call = std::make_shared<ClientCallImpl<Reply>>(callback, std::move(stats_handle), - method_timeout_ms); + auto call = std::make_shared<ClientCallImpl<Reply>>( + callback, std::move(stats_handle), method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 22e60c6d6..713524280 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -78,8 +78,8 @@ class Executor { /// /// Currently, SyncMETHOD will copy the reply additionally. /// TODO(sang): Fix it.
-#define VOID_GCS_RPC_CLIENT_METHOD(SERVICE, METHOD, grpc_client, method_timeout_ms, \ - SPECS) \ +#define VOID_GCS_RPC_CLIENT_METHOD( \ + SERVICE, METHOD, grpc_client, method_timeout_ms, SPECS) \ void METHOD(const METHOD##Request &request, \ const ClientCallback &callback, \ const int64_t timeout_ms = method_timeout_ms) SPECS { \ @@ -102,15 +102,20 @@ class Executor { executor->Retry(); \ } \ }; \ - auto operation = [request, operation_callback, \ - timeout_ms](GcsRpcClient *gcs_rpc_client) { \ - RAY_UNUSED(INVOKE_RPC_CALL(SERVICE, METHOD, request, operation_callback, \ - gcs_rpc_client->grpc_client, timeout_ms)); \ - }; \ + auto operation = \ + [request, operation_callback, timeout_ms](GcsRpcClient *gcs_rpc_client) { \ + RAY_UNUSED(INVOKE_RPC_CALL(SERVICE, \ + METHOD, \ + request, \ + operation_callback, \ + gcs_rpc_client->grpc_client, \ + timeout_ms)); \ + }; \ executor->Execute(operation); \ } \ \ - ray::Status Sync##METHOD(const METHOD##Request &request, METHOD##Reply *reply_in, \ + ray::Status Sync##METHOD(const METHOD##Request &request, \ + METHOD##Reply *reply_in, \ const int64_t timeout_ms = method_timeout_ms) { \ std::promise promise; \ METHOD( \ @@ -133,13 +138,16 @@ class GcsRpcClient { /// \param[in] gcs_service_failure_detected The function is used to redo subscription /// and reconnect to GCS RPC server when gcs service failure is detected. GcsRpcClient( - const std::string &address, const int port, ClientCallManager &client_call_manager, + const std::string &address, + const int port, + ClientCallManager &client_call_manager, std::function gcs_service_failure_detected = nullptr) : gcs_service_failure_detected_(std::move(gcs_service_failure_detected)) { Reset(address, port, client_call_manager); }; - void Reset(const std::string &address, const int port, + void Reset(const std::string &address, + const int port, ClientCallManager &client_call_manager) { job_info_grpc_client_ = std::make_unique>( address, port, client_call_manager); @@ -148,8 +156,8 @@ class GcsRpcClient { node_info_grpc_client_ = std::make_unique>( address, port, client_call_manager); node_resource_info_grpc_client_ = - std::make_unique>(address, port, - client_call_manager); + std::make_unique>( + address, port, client_call_manager); heartbeat_info_grpc_client_ = std::make_unique>( address, port, client_call_manager); stats_grpc_client_ = @@ -157,8 +165,8 @@ class GcsRpcClient { worker_info_grpc_client_ = std::make_unique>( address, port, client_call_manager); placement_group_info_grpc_client_ = - std::make_unique>(address, port, - client_call_manager); + std::make_unique>( + address, port, client_call_manager); internal_kv_grpc_client_ = std::make_unique>( address, port, client_call_manager); internal_pubsub_grpc_client_ = std::make_unique>( @@ -166,165 +174,243 @@ class GcsRpcClient { } /// Add job info to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, AddJob, job_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, + AddJob, + job_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Mark job as finished to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, MarkJobFinished, job_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, + MarkJobFinished, + job_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get information of all jobs from GCS Service. 
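The Sync##METHOD body generated by the macro above is easier to follow outside macro syntax: it bridges the callback-based METHOD() to a blocking call with a std::promise. A minimal sketch of that pattern, with hypothetical AsyncEcho/SyncEcho standing in for the generated pair:

#include <functional>
#include <future>
#include <string>

// Hypothetical stand-ins: AsyncEcho plays the role of the generated METHOD(),
// SyncEcho the role of the generated Sync##METHOD().
void AsyncEcho(const std::string &request,
               const std::function<void(int, const std::string &)> &callback) {
  callback(/*status=*/0, request);  // Reply immediately; a real client replies later.
}

int SyncEcho(const std::string &request, std::string *reply_in) {
  std::promise<int> promise;
  AsyncEcho(request, [&promise, reply_in](int status, const std::string &reply) {
    *reply_in = reply;          // Copy the reply out (the macro's TODO notes this copy).
    promise.set_value(status);  // Wake the blocked synchronous caller.
  });
  return promise.get_future().get();
}

int main() {
  std::string reply;
  return SyncEcho("ping", &reply);  // 0 on success.
}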
- VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, GetAllJobInfo, job_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, + GetAllJobInfo, + job_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Report job error to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, ReportJobError, job_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, + ReportJobError, + job_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get next job id from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, GetNextJobID, job_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, + GetNextJobID, + job_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Register actor via GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, RegisterActor, actor_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + RegisterActor, + actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Create actor via GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, CreateActor, actor_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + CreateActor, + actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get actor data from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, GetActorInfo, actor_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + GetActorInfo, + actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get actor data from GCS Service by name. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, GetNamedActorInfo, - actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + GetNamedActorInfo, + actor_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get all named actor names from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, ListNamedActors, - actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + ListNamedActors, + actor_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get all actor data from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, GetAllActorInfo, - actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + GetAllActorInfo, + actor_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Kill actor via GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, KillActorViaGcs, - actor_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(ActorInfoGcsService, + KillActorViaGcs, + actor_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Register a node to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, RegisterNode, node_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, + RegisterNode, + node_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Unregister a node from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, DrainNode, node_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, + DrainNode, + node_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get information of all nodes from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, GetAllNodeInfo, node_info_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, + GetAllNodeInfo, + node_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get internal config of the node from the GCS Service. 
- VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, GetInternalConfig, - node_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, + GetInternalConfig, + node_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get node's resources from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, GetResources, - node_resource_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, + GetResources, + node_resource_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get available resources of all nodes from the GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, GetAllAvailableResources, - node_resource_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, + GetAllAvailableResources, + node_resource_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Report resource usage of a node to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, ReportResourceUsage, - node_resource_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, + ReportResourceUsage, + node_resource_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get resource usage of all nodes from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, GetAllResourceUsage, - node_resource_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(NodeResourceInfoGcsService, + GetAllResourceUsage, + node_resource_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Report heartbeat of a node to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(HeartbeatInfoGcsService, ReportHeartbeat, - heartbeat_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(HeartbeatInfoGcsService, + ReportHeartbeat, + heartbeat_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Check GCS is alive. - VOID_GCS_RPC_CLIENT_METHOD(HeartbeatInfoGcsService, CheckAlive, - heartbeat_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(HeartbeatInfoGcsService, + CheckAlive, + heartbeat_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Add profile data to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(StatsGcsService, AddProfileData, stats_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(StatsGcsService, + AddProfileData, + stats_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get information of all profiles from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(StatsGcsService, GetAllProfileInfo, stats_grpc_client_, + VOID_GCS_RPC_CLIENT_METHOD(StatsGcsService, + GetAllProfileInfo, + stats_grpc_client_, /*method_timeout_ms*/ -1, ) /// Report a worker failure to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, ReportWorkerFailure, - worker_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, + ReportWorkerFailure, + worker_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get worker information from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, GetWorkerInfo, - worker_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, + GetWorkerInfo, + worker_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Get information of all workers from GCS Service. 
- VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, GetAllWorkerInfo, - worker_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, + GetAllWorkerInfo, + worker_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Add worker information to GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, AddWorkerInfo, - worker_info_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, + AddWorkerInfo, + worker_info_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Create placement group via GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, CreatePlacementGroup, + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, + CreatePlacementGroup, placement_group_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Remove placement group via GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, RemovePlacementGroup, + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, + RemovePlacementGroup, placement_group_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get placement group via GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetPlacementGroup, + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, + GetPlacementGroup, placement_group_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get placement group data from GCS Service by name. - VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetNamedPlacementGroup, + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, + GetNamedPlacementGroup, placement_group_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Get information of all placement group from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetAllPlacementGroup, + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, + GetAllPlacementGroup, placement_group_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Wait for placement group until ready via GCS Service. 
- VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, WaitPlacementGroupUntilReady, + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, + WaitPlacementGroupUntilReady, placement_group_info_grpc_client_, /*method_timeout_ms*/ -1, ) /// Operations for kv (Get, Put, Del, Exists) - VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, InternalKVGet, - internal_kv_grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, InternalKVPut, - internal_kv_grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, InternalKVDel, - internal_kv_grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, InternalKVExists, - internal_kv_grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, InternalKVKeys, - internal_kv_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, + InternalKVGet, + internal_kv_grpc_client_, + /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, + InternalKVPut, + internal_kv_grpc_client_, + /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, + InternalKVDel, + internal_kv_grpc_client_, + /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, + InternalKVExists, + internal_kv_grpc_client_, + /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalKVGcsService, + InternalKVKeys, + internal_kv_grpc_client_, + /*method_timeout_ms*/ -1, ) /// Operations for pubsub - VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsPublish, - internal_pubsub_grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberPoll, - internal_pubsub_grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberCommandBatch, - internal_pubsub_grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, + GcsPublish, + internal_pubsub_grpc_client_, + /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, + GcsSubscriberPoll, + internal_pubsub_grpc_client_, + /*method_timeout_ms*/ -1, ) + VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, + GcsSubscriberCommandBatch, + internal_pubsub_grpc_client_, + /*method_timeout_ms*/ -1, ) private: std::function gcs_service_failure_detected_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 90c15b7fa..24459d2b8 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -23,38 +23,44 @@ namespace ray { namespace rpc { -#define JOB_INFO_SERVICE_RPC_HANDLER(HANDLER) \ - RPC_SERVICE_HANDLER(JobInfoGcsService, HANDLER, \ +#define JOB_INFO_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER(JobInfoGcsService, \ + HANDLER, \ RayConfig::instance().gcs_max_active_rpcs_per_handler()) #define ACTOR_INFO_SERVICE_RPC_HANDLER(HANDLER, MAX_ACTIVE_RPCS) \ RPC_SERVICE_HANDLER(ActorInfoGcsService, HANDLER, MAX_ACTIVE_RPCS) -#define NODE_INFO_SERVICE_RPC_HANDLER(HANDLER) \ - RPC_SERVICE_HANDLER(NodeInfoGcsService, HANDLER, \ +#define NODE_INFO_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER(NodeInfoGcsService, \ + HANDLER, \ RayConfig::instance().gcs_max_active_rpcs_per_handler()) #define HEARTBEAT_INFO_SERVICE_RPC_HANDLER(HANDLER) \ RPC_SERVICE_HANDLER(HeartbeatInfoGcsService, HANDLER, -1) -#define NODE_RESOURCE_INFO_SERVICE_RPC_HANDLER(HANDLER) \ - 
RPC_SERVICE_HANDLER(NodeResourceInfoGcsService, HANDLER, \ +#define NODE_RESOURCE_INFO_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER(NodeResourceInfoGcsService, \ + HANDLER, \ RayConfig::instance().gcs_max_active_rpcs_per_handler()) -#define OBJECT_INFO_SERVICE_RPC_HANDLER(HANDLER) \ - RPC_SERVICE_HANDLER(ObjectInfoGcsService, HANDLER, \ +#define OBJECT_INFO_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER(ObjectInfoGcsService, \ + HANDLER, \ RayConfig::instance().gcs_max_active_rpcs_per_handler()) -#define STATS_SERVICE_RPC_HANDLER(HANDLER) \ - RPC_SERVICE_HANDLER(StatsGcsService, HANDLER, \ +#define STATS_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER( \ + StatsGcsService, HANDLER, RayConfig::instance().gcs_max_active_rpcs_per_handler()) + +#define WORKER_INFO_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER(WorkerInfoGcsService, \ + HANDLER, \ RayConfig::instance().gcs_max_active_rpcs_per_handler()) -#define WORKER_INFO_SERVICE_RPC_HANDLER(HANDLER) \ - RPC_SERVICE_HANDLER(WorkerInfoGcsService, HANDLER, \ - RayConfig::instance().gcs_max_active_rpcs_per_handler()) - -#define PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(HANDLER) \ - RPC_SERVICE_HANDLER(PlacementGroupInfoGcsService, HANDLER, \ +#define PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(HANDLER) \ + RPC_SERVICE_HANDLER(PlacementGroupInfoGcsService, \ + HANDLER, \ RayConfig::instance().gcs_max_active_rpcs_per_handler()) #define INTERNAL_KV_SERVICE_RPC_HANDLER(HANDLER) \ @@ -73,7 +79,8 @@ class JobInfoGcsServiceHandler { public: virtual ~JobInfoGcsServiceHandler() = default; - virtual void HandleAddJob(const AddJobRequest &request, AddJobReply *reply, + virtual void HandleAddJob(const AddJobRequest &request, + AddJobReply *reply, SendReplyCallback send_reply_callback) = 0; virtual void HandleMarkJobFinished(const MarkJobFinishedRequest &request, @@ -208,7 +215,8 @@ class NodeInfoGcsServiceHandler { RegisterNodeReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleDrainNode(const DrainNodeRequest &request, DrainNodeReply *reply, + virtual void HandleDrainNode(const DrainNodeRequest &request, + DrainNodeReply *reply, SendReplyCallback send_reply_callback) = 0; virtual void HandleGetAllNodeInfo(const GetAllNodeInfoRequest &request, @@ -306,7 +314,8 @@ class HeartbeatInfoGcsServiceHandler { virtual void HandleReportHeartbeat(const ReportHeartbeatRequest &request, ReportHeartbeatReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleCheckAlive(const CheckAliveRequest &request, CheckAliveReply *reply, + virtual void HandleCheckAlive(const CheckAliveRequest &request, + CheckAliveReply *reply, SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `HeartbeatInfoGcsService`. 
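All of these handler interfaces share the Handle##METHOD(request, reply, send_reply_callback) shape. A self-contained sketch using the heartbeat service's CheckAlive as the example (stand-in types; the real Request/Reply are protobuf messages, and the real SendReplyCallback carries a status plus success/failure callbacks, as the server_call.h hunk later in this patch shows):

#include <functional>
#include <iostream>

// Stand-in request/reply types for the sketch only.
struct CheckAliveRequest {};
struct CheckAliveReply {
  bool alive = false;
};
using SendReplyCallback = std::function<void()>;  // Simplified for the sketch.

class HeartbeatHandlerSketch {
 public:
  // One parameter per line once the signature exceeds the column limit.
  void HandleCheckAlive(const CheckAliveRequest &request,
                        CheckAliveReply *reply,
                        SendReplyCallback send_reply_callback) {
    reply->alive = true;    // Fill in the reply...
    send_reply_callback();  // ...then hand it back to the gRPC layer.
  }
};

int main() {
  CheckAliveRequest request;
  CheckAliveReply reply;
  HeartbeatHandlerSketch{}.HandleCheckAlive(
      request, &reply, [] { std::cout << "reply sent\n"; });
  return reply.alive ? 0 : 1;
}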
@@ -537,7 +546,8 @@ class InternalPubSubGcsServiceHandler { public: virtual ~InternalPubSubGcsServiceHandler() = default; - virtual void HandleGcsPublish(const GcsPublishRequest &request, GcsPublishReply *reply, + virtual void HandleGcsPublish(const GcsPublishRequest &request, + GcsPublishReply *reply, SendReplyCallback send_reply_callback) = 0; virtual void HandleGcsSubscriberPoll(const GcsSubscriberPollRequest &request, @@ -546,7 +556,8 @@ class InternalPubSubGcsServiceHandler { virtual void HandleGcsSubscriberCommandBatch( const GcsSubscriberCommandBatchRequest &request, - GcsSubscriberCommandBatchReply *reply, SendReplyCallback send_reply_callback) = 0; + GcsSubscriberCommandBatchReply *reply, + SendReplyCallback send_reply_callback) = 0; }; class InternalPubSubGrpcService : public GrpcService { diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index d1e576a29..087d30fcb 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -29,11 +29,14 @@ namespace rpc { // This macro wraps the logic to call a specific RPC method of a service, // to make it easier to implement a new RPC client. -#define INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client, \ - method_timeout_ms) \ - (rpc_client->CallMethod( \ - &SERVICE::Stub::PrepareAsync##METHOD, request, callback, \ - #SERVICE ".grpc_client." #METHOD, method_timeout_ms)) +#define INVOKE_RPC_CALL( \ + SERVICE, METHOD, request, callback, rpc_client, method_timeout_ms) \ + (rpc_client->CallMethod( \ + &SERVICE::Stub::PrepareAsync##METHOD, \ + request, \ + callback, \ + #SERVICE ".grpc_client." #METHOD, \ + method_timeout_ms)) // Define a void RPC client method. #define VOID_RPC_CLIENT_METHOD(SERVICE, METHOD, rpc_client, method_timeout_ms, SPECS) \ @@ -45,7 +48,9 @@ namespace rpc { template class GrpcClient { public: - GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, + GrpcClient(const std::string &address, + const int port, + ClientCallManager &call_manager, bool use_tls = false) : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ChannelArguments argument; @@ -60,8 +65,11 @@ class GrpcClient { stub_ = GrpcService::NewStub(channel); } - GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, - int num_threads, bool use_tls = false) + GrpcClient(const std::string &address, + const int port, + ClientCallManager &call_manager, + int num_threads, + bool use_tls = false) : client_call_manager_(call_manager), use_tls_(use_tls) { grpc::ResourceQuota quota; quota.SetMaxThreads(num_threads); @@ -93,10 +101,16 @@ class GrpcClient { template void CallMethod( const PrepareAsyncFunction prepare_async_function, - const Request &request, const ClientCallback &callback, - std::string call_name = "UNKNOWN_RPC", int64_t method_timeout_ms = -1) { + const Request &request, + const ClientCallback &callback, + std::string call_name = "UNKNOWN_RPC", + int64_t method_timeout_ms = -1) { auto call = client_call_manager_.CreateCall( - *stub_, prepare_async_function, request, callback, std::move(call_name), + *stub_, + prepare_async_function, + request, + callback, + std::move(call_name), method_timeout_ms); RAY_CHECK(call != nullptr); } @@ -109,7 +123,8 @@ class GrpcClient { bool use_tls_; std::shared_ptr BuildChannel(const grpc::ChannelArguments &argument, - const std::string &address, int port) { + const std::string &address, + int port) { std::shared_ptr channel; if (::RayConfig::instance().USE_TLS()) { std::string server_cert_file = @@ -125,11 
+140,12 @@ class GrpcClient { ssl_opts.pem_private_key = private_key; ssl_opts.pem_cert_chain = server_cert_chain; auto ssl_creds = grpc::SslCredentials(ssl_opts); - channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), ssl_creds, - argument); + channel = grpc::CreateCustomChannel( + address + ":" + std::to_string(port), ssl_creds, argument); } else { channel = grpc::CreateCustomChannel(address + ":" + std::to_string(port), - grpc::InsecureChannelCredentials(), argument); + grpc::InsecureChannelCredentials(), + argument); } return channel; }; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index b053e9bd5..820612f98 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -27,8 +27,10 @@ namespace ray { namespace rpc { -GrpcServer::GrpcServer(std::string name, const uint32_t port, - bool listen_to_localhost_only, int num_threads, +GrpcServer::GrpcServer(std::string name, + const uint32_t port, + bool listen_to_localhost_only, + int num_threads, int64_t keepalive_time_ms) : name_(std::move(name)), port_(port), diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 843c0acba..c03d87cda 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -28,13 +28,20 @@ namespace ray { namespace rpc { /// \param MAX_ACTIVE_RPCS Maximum number of RPCs to handle at the same time. -1 means no /// limit. -#define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - std::unique_ptr HANDLER##_call_factory( \ - new ServerCallFactoryImpl( \ - service_, &SERVICE::AsyncService::Request##HANDLER, service_handler_, \ - &SERVICE##Handler::Handle##HANDLER, cq, main_service_, \ - #SERVICE ".grpc_server." #HANDLER, MAX_ACTIVE_RPCS)); \ +#define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, \ + &SERVICE::AsyncService::Request##HANDLER, \ + service_handler_, \ + &SERVICE##Handler::Handle##HANDLER, \ + cq, \ + main_service_, \ + #SERVICE ".grpc_server." #HANDLER, \ + MAX_ACTIVE_RPCS)); \ server_call_factories->emplace_back(std::move(HANDLER##_call_factory)); // Define a void RPC client method. @@ -61,7 +68,9 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - GrpcServer(std::string name, const uint32_t port, bool listen_to_localhost_only, + GrpcServer(std::string name, + const uint32_t port, + bool listen_to_localhost_only, int num_threads = 1, int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/); diff --git a/src/ray/rpc/metrics_agent_client.h b/src/ray/rpc/metrics_agent_client.h index 89af932ee..f14f435f5 100644 --- a/src/ray/rpc/metrics_agent_client.h +++ b/src/ray/rpc/metrics_agent_client.h @@ -35,7 +35,8 @@ class MetricsAgentClient { /// \param[in] address Address of the metrics agent server. /// \param[in] port Port of the metrics agent server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - MetricsAgentClient(const std::string &address, const int port, + MetricsAgentClient(const std::string &address, + const int port, ClientCallManager &client_call_manager) { RAY_LOG(DEBUG) << "Initiate the metrics client of address:" << address << " port:" << port; @@ -47,14 +48,18 @@ class MetricsAgentClient { /// /// \param[in] request The request message. /// \param[in] callback The callback function that handles reply. 
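Condensing the BuildChannel hunks from grpc_client.h above: the client picks TLS or insecure credentials, then builds one custom channel. A sketch against the public gRPC API (the empty certificate string is a placeholder, not a working TLS setup):

#include <grpcpp/grpcpp.h>

#include <memory>
#include <string>

// One-function sketch of the two branches: TLS when configured, insecure otherwise.
std::shared_ptr<grpc::Channel> BuildChannelSketch(const std::string &address,
                                                  int port,
                                                  bool use_tls) {
  grpc::ChannelArguments argument;
  const std::string target = address + ":" + std::to_string(port);
  if (use_tls) {
    grpc::SslCredentialsOptions ssl_opts;
    ssl_opts.pem_root_certs = "";  // Placeholder; Ray reads these from config files.
    return grpc::CreateCustomChannel(target, grpc::SslCredentials(ssl_opts), argument);
  }
  return grpc::CreateCustomChannel(
      target, grpc::InsecureChannelCredentials(), argument);
}

int main() {
  // Channel creation is lazy, so this runs without a live server.
  auto channel = BuildChannelSketch("127.0.0.1", 50051, /*use_tls=*/false);
  return channel ? 0 : 1;
}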
- VOID_RPC_CLIENT_METHOD(ReporterService, ReportMetrics, grpc_client_, + VOID_RPC_CLIENT_METHOD(ReporterService, + ReportMetrics, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Report open census protobuf metrics to metrics agent. /// /// \param[in] request The request message. /// \param[in] callback The callback function that handles reply. - VOID_RPC_CLIENT_METHOD(ReporterService, ReportOCMetrics, grpc_client_, + VOID_RPC_CLIENT_METHOD(ReporterService, + ReportOCMetrics, + grpc_client_, /*method_timeout_ms*/ -1, ) private: diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 9b7bd2d93..e22f27476 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -35,14 +35,17 @@ class NodeManagerClient { /// \param[in] address Address of the node manager server. /// \param[in] port Port of the node manager server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - NodeManagerClient(const std::string &address, const int port, + NodeManagerClient(const std::string &address, + const int port, ClientCallManager &client_call_manager) { - grpc_client_ = std::make_unique>(address, port, - client_call_manager); + grpc_client_ = std::make_unique>( + address, port, client_call_manager); }; /// Get current node stats. - VOID_RPC_CLIENT_METHOD(NodeManagerService, GetNodeStats, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + GetNodeStats, + grpc_client_, /*method_timeout_ms*/ -1, ) void GetNodeStats(const ClientCallback &callback) { @@ -65,78 +68,113 @@ class NodeManagerWorkerClient /// \param[in] port Port of the node manager server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. static std::shared_ptr make( - const std::string &address, const int port, + const std::string &address, + const int port, ClientCallManager &client_call_manager) { auto instance = new NodeManagerWorkerClient(address, port, client_call_manager); return std::shared_ptr(instance); } /// Update cluster resource usage. - VOID_RPC_CLIENT_METHOD(NodeManagerService, UpdateResourceUsage, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + UpdateResourceUsage, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Request a resource report. - VOID_RPC_CLIENT_METHOD(NodeManagerService, RequestResourceReport, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + RequestResourceReport, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Request a worker lease. - VOID_RPC_CLIENT_METHOD(NodeManagerService, RequestWorkerLease, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + RequestWorkerLease, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Report task backlog information - VOID_RPC_CLIENT_METHOD(NodeManagerService, ReportWorkerBacklog, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + ReportWorkerBacklog, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Return a worker lease. - VOID_RPC_CLIENT_METHOD(NodeManagerService, ReturnWorker, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + ReturnWorker, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Release unused workers. - VOID_RPC_CLIENT_METHOD(NodeManagerService, ReleaseUnusedWorkers, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + ReleaseUnusedWorkers, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Shutdown the raylet gracefully. 
- VOID_RPC_CLIENT_METHOD(NodeManagerService, ShutdownRaylet, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + ShutdownRaylet, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Cancel a pending worker lease request. - VOID_RPC_CLIENT_METHOD(NodeManagerService, CancelWorkerLease, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + CancelWorkerLease, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Request prepare resources for an atomic placement group creation. - VOID_RPC_CLIENT_METHOD(NodeManagerService, PrepareBundleResources, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + PrepareBundleResources, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Request commit resources for an atomic placement group creation. - VOID_RPC_CLIENT_METHOD(NodeManagerService, CommitBundleResources, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + CommitBundleResources, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Return resource lease. - VOID_RPC_CLIENT_METHOD(NodeManagerService, CancelResourceReserve, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + CancelResourceReserve, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Notify the raylet to pin the provided object IDs. - VOID_RPC_CLIENT_METHOD(NodeManagerService, PinObjectIDs, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + PinObjectIDs, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Trigger global GC across the cluster. - VOID_RPC_CLIENT_METHOD(NodeManagerService, GlobalGC, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + GlobalGC, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Ask the raylet to spill an object to external storage. - VOID_RPC_CLIENT_METHOD(NodeManagerService, RequestObjectSpillage, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + RequestObjectSpillage, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Release unused bundles. - VOID_RPC_CLIENT_METHOD(NodeManagerService, ReleaseUnusedBundles, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + ReleaseUnusedBundles, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Get the system config from Raylet. - VOID_RPC_CLIENT_METHOD(NodeManagerService, GetSystemConfig, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + GetSystemConfig, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Get gcs server address. - VOID_RPC_CLIENT_METHOD(NodeManagerService, GetGcsServerAddress, grpc_client_, + VOID_RPC_CLIENT_METHOD(NodeManagerService, + GetGcsServerAddress, + grpc_client_, /*method_timeout_ms*/ -1, ) private: @@ -145,10 +183,11 @@ class NodeManagerWorkerClient /// \param[in] address Address of the node manager server. /// \param[in] port Port of the node manager server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - NodeManagerWorkerClient(const std::string &address, const int port, + NodeManagerWorkerClient(const std::string &address, + const int port, ClientCallManager &client_call_manager) { - grpc_client_ = std::make_unique>(address, port, - client_call_manager); + grpc_client_ = std::make_unique>( + address, port, client_call_manager); }; /// The RPC client. 
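One detail worth calling out before the ObjectManagerClient hunks below: that client keeps a separate round-robin index per RPC type (push_rr_index_, pull_rr_index_, freeobjects_rr_index_) and rotates across num_connections underlying clients. A toy sketch of that selection logic, with strings standing in for the per-connection GrpcClient instances:

#include <cstdlib>
#include <string>
#include <vector>

class RoundRobinClients {
 public:
  explicit RoundRobinClients(int num_connections)
      : num_connections_(num_connections),
        clients_(num_connections, "connection"),
        push_rr_index_(std::rand() % num_connections) {}

  // Each Push-style call uses the next connection, spreading load evenly.
  const std::string &NextPushClient() {
    return clients_[push_rr_index_++ % num_connections_];
  }

 private:
  const int num_connections_;
  std::vector<std::string> clients_;
  size_t push_rr_index_;
};

int main() {
  RoundRobinClients clients(/*num_connections=*/4);
  for (int i = 0; i < 8; i++) {
    (void)clients.NextPushClient();  // Cycles through 4 connections twice.
  }
  return 0;
}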
diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index e1fac4176..fe457aa7d 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -115,7 +115,8 @@ class NodeManagerServiceHandler { GetNodeStatsReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleGlobalGC(const GlobalGCRequest &request, GlobalGCReply *reply, + virtual void HandleGlobalGC(const GlobalGCRequest &request, + GlobalGCReply *reply, SendReplyCallback send_reply_callback) = 0; virtual void HandleFormatGlobalMemoryInfo(const FormatGlobalMemoryInfoRequest &request, diff --git a/src/ray/rpc/object_manager/object_manager_client.h b/src/ray/rpc/object_manager/object_manager_client.h index e4947b1e6..2db29b950 100644 --- a/src/ray/rpc/object_manager/object_manager_client.h +++ b/src/ray/rpc/object_manager/object_manager_client.h @@ -37,8 +37,10 @@ class ObjectManagerClient { /// \param[in] address Address of the node manager server. /// \param[in] port Port of the node manager server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - ObjectManagerClient(const std::string &address, const int port, - ClientCallManager &client_call_manager, int num_connections = 4) + ObjectManagerClient(const std::string &address, + const int port, + ClientCallManager &client_call_manager, + int num_connections = 4) : num_connections_(num_connections) { push_rr_index_ = rand() % num_connections_; pull_rr_index_ = rand() % num_connections_; @@ -54,7 +56,8 @@ class ObjectManagerClient { /// /// \param request The request message. /// \param callback The callback function that handles reply from server - VOID_RPC_CLIENT_METHOD(ObjectManagerService, Push, + VOID_RPC_CLIENT_METHOD(ObjectManagerService, + Push, grpc_clients_[push_rr_index_++ % num_connections_], /*method_timeout_ms*/ -1, ) @@ -62,7 +65,8 @@ class ObjectManagerClient { /// /// \param request The request message /// \param callback The callback function that handles reply from server - VOID_RPC_CLIENT_METHOD(ObjectManagerService, Pull, + VOID_RPC_CLIENT_METHOD(ObjectManagerService, + Pull, grpc_clients_[pull_rr_index_++ % num_connections_], /*method_timeout_ms*/ -1, ) @@ -70,7 +74,8 @@ class ObjectManagerClient { /// /// \param request The request message /// \param callback The callback function that handles reply - VOID_RPC_CLIENT_METHOD(ObjectManagerService, FreeObjects, + VOID_RPC_CLIENT_METHOD(ObjectManagerService, + FreeObjects, grpc_clients_[freeobjects_rr_index_++ % num_connections_], /*method_timeout_ms*/ -1, ) diff --git a/src/ray/rpc/object_manager/object_manager_server.h b/src/ray/rpc/object_manager/object_manager_server.h index dfad13cbc..bcd886861 100644 --- a/src/ray/rpc/object_manager/object_manager_server.h +++ b/src/ray/rpc/object_manager/object_manager_server.h @@ -39,10 +39,12 @@ class ObjectManagerServiceHandler { /// \param[in] request The request message. /// \param[out] reply The reply message. /// \param[in] send_reply_callback The callback to be called when the request is done. 
- virtual void HandlePush(const PushRequest &request, PushReply *reply, + virtual void HandlePush(const PushRequest &request, + PushReply *reply, SendReplyCallback send_reply_callback) = 0; /// Handle a `Pull` request - virtual void HandlePull(const PullRequest &request, PullReply *reply, + virtual void HandlePull(const PullRequest &request, + PullReply *reply, SendReplyCallback send_reply_callback) = 0; /// Handle a `FreeObjects` request virtual void HandleFreeObjects(const FreeObjectsRequest &request, diff --git a/src/ray/rpc/runtime_env/runtime_env_client.h b/src/ray/rpc/runtime_env/runtime_env_client.h index 849b18847..4192233c5 100644 --- a/src/ray/rpc/runtime_env/runtime_env_client.h +++ b/src/ray/rpc/runtime_env/runtime_env_client.h @@ -39,24 +39,29 @@ class RuntimeEnvAgentClient : public RuntimeEnvAgentClientInterface { /// \param[in] address Address of the server. /// \param[in] port Port of the server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - RuntimeEnvAgentClient(const std::string &address, const int port, + RuntimeEnvAgentClient(const std::string &address, + const int port, ClientCallManager &client_call_manager) { - grpc_client_ = std::make_unique<GrpcClient<RuntimeEnvService>>(address, port, - client_call_manager); + grpc_client_ = std::make_unique<GrpcClient<RuntimeEnvService>>( + address, port, client_call_manager); }; /// Create runtime env. /// /// \param request The request message /// \param callback The callback function that handles reply - VOID_RPC_CLIENT_METHOD(RuntimeEnvService, CreateRuntimeEnv, grpc_client_, + VOID_RPC_CLIENT_METHOD(RuntimeEnvService, + CreateRuntimeEnv, + grpc_client_, /*method_timeout_ms*/ -1, ) /// Delete URIs. /// /// \param request The request message /// \param callback The callback function that handles reply - VOID_RPC_CLIENT_METHOD(RuntimeEnvService, DeleteURIs, grpc_client_, + VOID_RPC_CLIENT_METHOD(RuntimeEnvService, + DeleteURIs, + grpc_client_, /*method_timeout_ms*/ -1, ) private: diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index a25f4da31..7b42c5ac8 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -35,8 +35,8 @@ namespace rpc { /// sent to the client. /// \param failure Failure callback which will be invoked when the reply fails to be /// sent to the client. -using SendReplyCallback = std::function<void(Status status, std::function<void()> success, - std::function<void()> failure)>; +using SendReplyCallback = std::function<void( + Status status, std::function<void()> success, std::function<void()> failure)>; /// Represents state of a `ServerCall`. enum class ServerCallState { @@ -113,7 +113,8 @@ class ServerCallFactory { /// \tparam Request Type of the request message. /// \tparam Reply Type of the reply message. template <class ServiceHandler, class Request, class Reply> -using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *, +using HandleRequestFunction = void (ServiceHandler::*)(const Request &, + Reply *, SendReplyCallback); /// Implementation of `ServerCall`. It represents `ServerCall` for a particular @@ -132,9 +133,11 @@ class ServerCallImpl : public ServerCall { /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] io_service The event loop.
ServerCallImpl( - const ServerCallFactory &factory, ServiceHandler &service_handler, + const ServerCallFactory &factory, + ServiceHandler &service_handler, HandleRequestFunction handle_request_function, - instrumented_io_context &io_service, std::string call_name) + instrumented_io_context &io_service, + std::string call_name) : state_(ServerCallState::PENDING), factory_(factory), service_handler_(service_handler), @@ -181,9 +184,10 @@ class ServerCallImpl : public ServerCall { factory.CreateCall(); } (service_handler_.*handle_request_function_)( - request_, reply_, - [this](Status status, std::function success, - std::function failure) { + request_, + reply_, + [this]( + Status status, std::function success, std::function failure) { // These two callbacks must be set before `SendReply`, because `SendReply` // is async and this `ServerCall` might be deleted right after `SendReply`. send_reply_success_callback_ = std::move(success); @@ -284,9 +288,13 @@ class ServerCallImpl : public ServerCall { /// \tparam Request Type of the request message. /// \tparam Reply Type of the reply message. template -using RequestCallFunction = void (GrpcService::AsyncService::*)( - grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter *, - grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *); +using RequestCallFunction = + void (GrpcService::AsyncService::*)(grpc::ServerContext *, + Request *, + grpc::ServerAsyncResponseWriter *, + grpc::CompletionQueue *, + grpc::ServerCompletionQueue *, + void *); /// Implementation of `ServerCallFactory` /// @@ -316,7 +324,9 @@ class ServerCallFactoryImpl : public ServerCallFactory { ServiceHandler &service_handler, HandleRequestFunction handle_request_function, const std::unique_ptr &cq, - instrumented_io_context &io_service, std::string call_name, int64_t max_active_rpcs) + instrumented_io_context &io_service, + std::string call_name, + int64_t max_active_rpcs) : service_(service), request_call_function_(request_call_function), service_handler_(service_handler), @@ -333,8 +343,11 @@ class ServerCallFactoryImpl : public ServerCallFactory { *this, service_handler_, handle_request_function_, io_service_, call_name_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. 
- (service_.*request_call_function_)(&call->context_, &call->request_, - &call->response_writer_, cq_.get(), cq_.get(), + (service_.*request_call_function_)(&call->context_, + &call->request_, + &call->response_writer_, + cq_.get(), + cq_.get(), call); } diff --git a/src/ray/rpc/test/grpc_server_client_test.cc b/src/ray/rpc/test/grpc_server_client_test.cc index 18d9294fb..46b08ccd7 100644 --- a/src/ray/rpc/test/grpc_server_client_test.cc +++ b/src/ray/rpc/test/grpc_server_client_test.cc @@ -23,7 +23,8 @@ namespace ray { namespace rpc { class TestServiceHandler { public: - void HandlePing(const PingRequest &request, PingReply *reply, + void HandlePing(const PingRequest &request, + PingReply *reply, SendReplyCallback send_reply_callback) { RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); request_count++; @@ -46,7 +47,8 @@ class TestServiceHandler { }); } - void HandlePingTimeout(const PingTimeoutRequest &request, PingTimeoutReply *reply, + void HandlePingTimeout(const PingTimeoutRequest &request, + PingTimeoutReply *reply, SendReplyCallback send_reply_callback) { while (frozen) { RAY_LOG(INFO) << "Server is frozen..."; @@ -120,8 +122,8 @@ class TestGrpcServerClientFixture : public ::testing::Test { client_io_service_.run(); }); client_call_manager_.reset(new ClientCallManager(client_io_service_)); - grpc_client_.reset(new GrpcClient("127.0.0.1", grpc_server_->GetPort(), - *client_call_manager_)); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); } void ShutdownClient() { @@ -149,7 +151,9 @@ class TestGrpcServerClientFixture : public ::testing::Test { protected: VOID_RPC_CLIENT_METHOD(TestService, Ping, grpc_client_, /*method_timeout_ms*/ -1, ) - VOID_RPC_CLIENT_METHOD(TestService, PingTimeout, grpc_client_, + VOID_RPC_CLIENT_METHOD(TestService, + PingTimeout, + grpc_client_, /*method_timeout_ms*/ 100, ) // Server TestServiceHandler test_service_handler_; @@ -205,10 +209,11 @@ TEST_F(TestGrpcServerClientFixture, TestClientCallManagerTimeout) { // Reinit ClientCallManager with short timeout. grpc_client_.reset(); client_call_manager_.reset(); - client_call_manager_.reset(new ClientCallManager(client_io_service_, /*num_thread=*/1, + client_call_manager_.reset(new ClientCallManager(client_io_service_, + /*num_thread=*/1, /*call_timeout_ms=*/100)); - grpc_client_.reset(new GrpcClient("127.0.0.1", grpc_server_->GetPort(), - *client_call_manager_)); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); // Freeze server first, it won't reply any request. test_service_handler_.frozen = true; // Send request. @@ -237,10 +242,11 @@ TEST_F(TestGrpcServerClientFixture, TestClientDiedBeforeReply) { // Reinit ClientCallManager with short timeout, so that call won't block. grpc_client_.reset(); client_call_manager_.reset(); - client_call_manager_.reset(new ClientCallManager(client_io_service_, /*num_thread=*/1, + client_call_manager_.reset(new ClientCallManager(client_io_service_, + /*num_thread=*/1, /*call_timeout_ms=*/100)); - grpc_client_.reset(new GrpcClient("127.0.0.1", grpc_server_->GetPort(), - *client_call_manager_)); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); // Freeze server first, it won't reply any request. test_service_handler_.frozen = true; // Send request. @@ -267,8 +273,8 @@ TEST_F(TestGrpcServerClientFixture, TestClientDiedBeforeReply) { } // Reinit client with infinite timeout. 
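What these timeout tests exercise, reduced to its core: a client waits a bounded 100 ms for a reply that a frozen server never sends, then gives up. An analogy in plain C++ futures (not the actual gRPC machinery):

#include <chrono>
#include <future>

int main() {
  std::promise<void> reply_promise;  // Never fulfilled: the server is "frozen".
  auto reply = reply_promise.get_future();
  // Equivalent of call_timeout_ms = 100 in the tests above.
  auto status = reply.wait_for(std::chrono::milliseconds(100));
  return status == std::future_status::timeout ? 0 : 1;  // Expect a timeout.
}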
diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h
index 1cd31a5af..f63e75b1f 100644
--- a/src/ray/rpc/worker/core_worker_client.h
+++ b/src/ray/rpc/worker/core_worker_client.h
@@ -109,7 +109,8 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
   /// task for execution immediately.
   /// \param[in] callback The callback function that handles reply.
   /// \return if the rpc call succeeds
-  virtual void PushActorTask(std::unique_ptr<PushTaskRequest> request, bool skip_queue,
+  virtual void PushActorTask(std::unique_ptr<PushTaskRequest> request,
+                             bool skip_queue,
                              const ClientCallback<PushTaskReply> &callback) {}

   /// Similar to PushActorTask, but sets no ordering constraint. This is used to
@@ -218,71 +219,130 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
   const rpc::Address &Addr() const override { return addr_; }

-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, DirectActorCallArgWaitComplete, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, GetObjectStatus, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, KillActor, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, CancelTask, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, RemoteCancelTask, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, WaitForActorOutOfScope, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, PubsubLongPolling, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, PubsubCommandBatch, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, UpdateObjectLocationBatch, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, GetObjectLocationsOwner, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, LocalGC, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, SpillObjects, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, RestoreSpilledObjects, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, DeleteSpilledObjects, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, AddSpilledUrl, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, PlasmaObjectReady, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
-
-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, Exit, grpc_client_, /*method_timeout_ms*/ -1,
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         DirectActorCallArgWaitComplete,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
                          override)

-  VOID_RPC_CLIENT_METHOD(CoreWorkerService, AssignObjectOwner, grpc_client_,
-                         /*method_timeout_ms*/ -1, override)
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         GetObjectStatus,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)

-  void PushActorTask(std::unique_ptr<PushTaskRequest> request, bool skip_queue,
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         KillActor,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         CancelTask,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         RemoteCancelTask,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         WaitForActorOutOfScope,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         PubsubLongPolling,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         PubsubCommandBatch,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         UpdateObjectLocationBatch,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         GetObjectLocationsOwner,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         GetCoreWorkerStats,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         LocalGC,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         SpillObjects,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         RestoreSpilledObjects,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         DeleteSpilledObjects,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         AddSpilledUrl,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         PlasmaObjectReady,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  VOID_RPC_CLIENT_METHOD(
+      CoreWorkerService, Exit, grpc_client_, /*method_timeout_ms*/ -1, override)
+
+  VOID_RPC_CLIENT_METHOD(CoreWorkerService,
+                         AssignObjectOwner,
+                         grpc_client_,
+                         /*method_timeout_ms*/ -1,
+                         override)
+
+  void PushActorTask(std::unique_ptr<PushTaskRequest> request,
+                     bool skip_queue,
                      const ClientCallback<PushTaskReply> &callback) override {
     if (skip_queue) {
       // Set this value so that the actor does not skip any tasks when
       // processing this request. We could also set it to max_finished_seq_no_,
       // but we just set it to the default of -1 to avoid taking the lock.
       request->set_client_processed_up_to(-1);
-      INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, callback, grpc_client_,
+      INVOKE_RPC_CALL(CoreWorkerService,
+                      PushTask,
+                      *request,
+                      callback,
+                      grpc_client_,
                       /*method_timeout_ms*/ -1);
       return;
     }
@@ -300,7 +360,11 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
                      const ClientCallback<PushTaskReply> &callback) override {
     request->set_sequence_number(-1);
     request->set_client_processed_up_to(-1);
-    INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, callback, grpc_client_,
+    INVOKE_RPC_CALL(CoreWorkerService,
+                    PushTask,
+                    *request,
+                    callback,
+                    grpc_client_,
                     /*method_timeout_ms*/ -1);
   }

@@ -323,23 +387,26 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
       request->set_client_processed_up_to(max_finished_seq_no_);
       rpc_bytes_in_flight_ += task_size;

-      auto rpc_callback = [this, this_ptr, seq_no, task_size,
-                           callback = std::move(pair.second)](
-                              Status status, const rpc::PushTaskReply &reply) {
-        {
-          absl::MutexLock lock(&mutex_);
-          if (seq_no > max_finished_seq_no_) {
-            max_finished_seq_no_ = seq_no;
-          }
-          rpc_bytes_in_flight_ -= task_size;
-          RAY_CHECK(rpc_bytes_in_flight_ >= 0);
-        }
-        SendRequests();
-        callback(status, reply);
-      };
+      auto rpc_callback =
+          [this, this_ptr, seq_no, task_size, callback = std::move(pair.second)](
+              Status status, const rpc::PushTaskReply &reply) {
+            {
+              absl::MutexLock lock(&mutex_);
+              if (seq_no > max_finished_seq_no_) {
+                max_finished_seq_no_ = seq_no;
+              }
+              rpc_bytes_in_flight_ -= task_size;
+              RAY_CHECK(rpc_bytes_in_flight_ >= 0);
+            }
+            SendRequests();
+            callback(status, reply);
+          };

-      RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request,
-                                 std::move(rpc_callback), grpc_client_,
+      RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService,
+                                 PushTask,
+                                 *request,
+                                 std::move(rpc_callback),
+                                 grpc_client_,
                                  /*method_timeout_ms*/ -1));
   }
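Two things worth noting in the hunks above. First, VOID_RPC_CLIENT_METHOD and
INVOKE_RPC_CALL are macros, but clang-format lays out macro invocations with the
same rules as function calls, which is why the short Exit entry collapses to the
break-after-paren form while longer names get one argument per line. Second, the
reshaped rpc_callback lambda is the client-side flow-control bookkeeping; a minimal
standalone sketch of that pattern (simplified names and types, not the actual Ray
implementation):

    #include <cstdint>
    #include <mutex>

    class FlowControlledClient {
     public:
      // Completion callback: update accounting, then try to send more.
      void OnReplyReceived(int64_t seq_no, int64_t task_size) {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          // Replies may finish out of order; only advance the high-water mark.
          if (seq_no > max_finished_seq_no_) {
            max_finished_seq_no_ = seq_no;
          }
          bytes_in_flight_ -= task_size;
        }
        SendMore();  // re-enter the send loop outside the lock
      }

     private:
      void SendMore() { /* pop queued requests while bytes_in_flight_ is low */ }

      std::mutex mutex_;
      int64_t max_finished_seq_no_ = -1;
      int64_t bytes_in_flight_ = 0;
    };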
diff --git a/src/ray/stats/metric.cc b/src/ray/stats/metric.cc
index bb1a5be72..b27b3cace 100644
--- a/src/ray/stats/metric.cc
+++ b/src/ray/stats/metric.cc
@@ -122,7 +122,8 @@ void Metric::Record(double value, const std::unordered_map<std::string, std::string> &tags) {
   TagsType tags_pair_vec;
   std::for_each(
-      tags.begin(), tags.end(),
+      tags.begin(),
+      tags.end(),
       [&tags_pair_vec](std::pair<std::string, std::string> tag) {
         return tags_pair_vec.push_back({TagKeyType::Register(tag.first), tag.second});
       });
diff --git a/src/ray/stats/metric.h b/src/ray/stats/metric.h
index 8ab3ea682..d121a7506 100644
--- a/src/ray/stats/metric.h
+++ b/src/ray/stats/metric.h
@@ -101,7 +101,9 @@ class StatsConfig final {
 /// A thin wrapper that wraps the `opencensus::tag::measure` for using it simply.
 class Metric {
  public:
-  Metric(const std::string &name, const std::string &description, const std::string &unit,
+  Metric(const std::string &name,
+         const std::string &description,
+         const std::string &unit,
          const std::vector<std::string> &tag_keys = {})
       : name_(name),
         description_(description),
@@ -148,7 +150,9 @@ class Metric {
 class Gauge : public Metric {
  public:
-  Gauge(const std::string &name, const std::string &description, const std::string &unit,
+  Gauge(const std::string &name,
+        const std::string &description,
+        const std::string &unit,
         const std::vector<std::string> &tag_keys = {})
       : Metric(name, description, unit, tag_keys) {}

@@ -159,8 +163,10 @@ class Gauge : public Metric {
 class Histogram : public Metric {
  public:
-  Histogram(const std::string &name, const std::string &description,
-            const std::string &unit, const std::vector<double> boundaries,
+  Histogram(const std::string &name,
+            const std::string &description,
+            const std::string &unit,
+            const std::vector<double> boundaries,
             const std::vector<std::string> &tag_keys = {})
       : Metric(name, description, unit, tag_keys), boundaries_(boundaries) {}

@@ -174,7 +180,9 @@ class Histogram : public Metric {
 class Count : public Metric {
  public:
-  Count(const std::string &name, const std::string &description, const std::string &unit,
+  Count(const std::string &name,
+        const std::string &description,
+        const std::string &unit,
         const std::vector<std::string> &tag_keys = {})
       : Metric(name, description, unit, tag_keys) {}

@@ -185,7 +193,9 @@ class Count : public Metric {
 class Sum : public Metric {
  public:
-  Sum(const std::string &name, const std::string &description, const std::string &unit,
+  Sum(const std::string &name,
+      const std::string &description,
+      const std::string &unit,
       const std::vector<std::string> &tag_keys = {})
       : Metric(name, description, unit, tag_keys) {}

@@ -247,7 +257,8 @@ struct StatsTypeMap {
 };

 template
-void RegisterView(const std::string &name, const std::string &description,
+void RegisterView(const std::string &name,
+                  const std::string &description,
                   const std::vector<std::string> &tag_keys,
                   const std::vector<double> &buckets) {
   using I = StatsTypeMap;
@@ -260,14 +271,16 @@ void RegisterView(const std::string &name, const std::string &description,
 }

 template
-void RegisterViewWithTagList(const std::string &name, const std::string &description,
+void RegisterViewWithTagList(const std::string &name,
+                             const std::string &description,
                              const std::vector<std::string> &tag_keys,
                              const std::vector<double> &buckets) {
   static_assert(std::is_same_v);
 }

 template
-void RegisterViewWithTagList(const std::string &name, const std::string &description,
+void RegisterViewWithTagList(const std::string &name,
+                             const std::string &description,
                              const std::vector<std::string> &tag_keys,
                              const std::vector<double> &buckets) {
   RegisterView(name, description, tag_keys, buckets);
@@ -297,12 +310,14 @@ class Stats {
   /// \param measure The name for the metric
   /// \description The description for the metric
   /// \register_func The function to register the metric
-  Stats(const std::string &measure, const std::string &description,
-        std::vector<std::string> tag_keys, std::vector<double> buckets,
-        std::function<void(const std::string &, const std::string &,
-                           const std::vector<std::string> &,
-                           const std::vector<double> &buckets)>
-            register_func)
+  Stats(const std::string &measure,
+        const std::string &description,
+        std::vector<std::string> tag_keys,
+        std::vector<double> buckets,
+        std::function<void(const std::string &,
+                           const std::string &,
+                           const std::vector<std::string> &,
+                           const std::vector<double> &buckets)> register_func)
       : tag_keys_(convert_tags(tag_keys)) {
     auto stats_init = [register_func, measure, description, buckets, this]() {
       measure_ = std::make_unique<Measure>(Measure::Register(measure, description, ""));
@@ -402,7 +417,10 @@ class Stats {
   (), ray::stats::GAUGE);
   STATS_async_pool_req_execution_time_ms.record(1, "method");
 */
-#define DEFINE_stats(name, description, tags, buckets, ...)                 \
-  ray::stats::internal::Stats STATS_##name(                                 \
-      #name, description, {STATS_DEPAREN(tags)}, {STATS_DEPAREN(buckets)},  \
+#define DEFINE_stats(name, description, tags, buckets, ...) \
+  ray::stats::internal::Stats STATS_##name(                 \
+      #name,                                                 \
+      description,                                           \
+      {STATS_DEPAREN(tags)},                                  \
+      {STATS_DEPAREN(buckets)},                               \
       ray::stats::internal::RegisterViewWithTagList<__VA_ARGS__>)
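With the macro body reflowed, a new stat definition and its use look like this
(stat name, tag, and buckets are hypothetical, mirroring the usage comment in the
header above):

    DEFINE_stats(demo_req_latency_ms,
                 "Latency of demo requests.",
                 ("Method"),
                 ({1, 10, 100, 1000}, ),
                 ray::stats::HISTOGRAM);

    STATS_demo_req_latency_ms.record(42, "Ping");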
diff --git a/src/ray/stats/metric_defs.cc b/src/ray/stats/metric_defs.cc
index 5fb2d2c81..fe3e5a8d1 100644
--- a/src/ray/stats/metric_defs.cc
+++ b/src/ray/stats/metric_defs.cc
@@ -33,87 +33,138 @@ namespace stats {

 /// Event stats
 DEFINE_stats(operation_count, "operation count", ("Method"), (), ray::stats::GAUGE);
-DEFINE_stats(operation_run_time_ms, "operation execution time", ("Method"), (),
-             ray::stats::GAUGE);
-DEFINE_stats(operation_queue_time_ms, "operation queuing time", ("Method"), (),
-             ray::stats::GAUGE);
-DEFINE_stats(operation_active_count, "activate operation number", ("Method"), (),
+DEFINE_stats(
+    operation_run_time_ms, "operation execution time", ("Method"), (), ray::stats::GAUGE);
+DEFINE_stats(
+    operation_queue_time_ms, "operation queuing time", ("Method"), (), ray::stats::GAUGE);
+DEFINE_stats(operation_active_count,
+             "activate operation number",
+             ("Method"),
+             (),
              ray::stats::GAUGE);

 /// GRPC server
-DEFINE_stats(grpc_server_req_process_time_ms, "Request latency in grpc server",
-             ("Method"), (), ray::stats::GAUGE);
-DEFINE_stats(grpc_server_req_new, "New request number in grpc server", ("Method"), (),
+DEFINE_stats(grpc_server_req_process_time_ms,
+             "Request latency in grpc server",
+             ("Method"),
+             (),
+             ray::stats::GAUGE);
+DEFINE_stats(grpc_server_req_new,
+             "New request number in grpc server",
+             ("Method"),
+             (),
+             ray::stats::COUNT);
+DEFINE_stats(grpc_server_req_handling,
+             "Request number are handling in grpc server",
+             ("Method"),
+             (),
+             ray::stats::COUNT);
+DEFINE_stats(grpc_server_req_finished,
+             "Finished request number in grpc server",
+             ("Method"),
+             (),
              ray::stats::COUNT);
-DEFINE_stats(grpc_server_req_handling, "Request number are handling in grpc server",
-             ("Method"), (), ray::stats::COUNT);
-DEFINE_stats(grpc_server_req_finished, "Finished request number in grpc server",
-             ("Method"), (), ray::stats::COUNT);

 /// Object Manager.
 DEFINE_stats(object_manager_received_chunks,
              "Number object chunks received broken per type {Total, FailedTotal, "
              "FailedCancelled, FailedPlasmaFull}.",
-             ("Type"), (), ray::stats::GAUGE);
+             ("Type"),
+             (),
+             ray::stats::GAUGE);

 /// Pull Manager
 DEFINE_stats(
     pull_manager_usage_bytes,
     "The total number of bytes usage broken per type {Available, BeingPulled, Pinned}",
-    ("Type"), (), ray::stats::GAUGE);
+    ("Type"),
+    (),
+    ray::stats::GAUGE);
 DEFINE_stats(pull_manager_requested_bundles,
              "Number of requested bundles broken per type {Get, Wait, TaskArgs}.",
-             ("Type"), (), ray::stats::GAUGE);
+             ("Type"),
+             (),
+             ray::stats::GAUGE);
 DEFINE_stats(pull_manager_requests,
              "Number of pull requests broken per type {Queued, Active, Pinned}.",
-             ("Type"), (), ray::stats::GAUGE);
-DEFINE_stats(pull_manager_active_bundles, "Number of active bundle requests", (), (),
+             ("Type"),
+             (),
              ray::stats::GAUGE);
-DEFINE_stats(pull_manager_retries_total, "Number of cumulative pull retries.", (), (),
+DEFINE_stats(pull_manager_active_bundles,
+             "Number of active bundle requests",
+             (),
+             (),
+             ray::stats::GAUGE);
+DEFINE_stats(pull_manager_retries_total,
+             "Number of cumulative pull retries.",
+             (),
+             (),
              ray::stats::GAUGE);

 /// Push Manager
-DEFINE_stats(push_manager_in_flight_pushes, "Number of in flight object push requests.",
-             (), (), ray::stats::GAUGE);
+DEFINE_stats(push_manager_in_flight_pushes,
+             "Number of in flight object push requests.",
+             (),
+             (),
+             ray::stats::GAUGE);
 DEFINE_stats(push_manager_chunks,
              "Number of object chunks transfer broken per type {InFlight, Remaining}.",
-             ("Type"), (), ray::stats::GAUGE);
+             ("Type"),
+             (),
+             ray::stats::GAUGE);

 /// Scheduler
 DEFINE_stats(
     scheduler_tasks,
     "Number of tasks waiting for scheduling broken per state {Cancelled, Executing, "
     "Waiting, Dispatched, Received}.",
-    ("State"), (), ray::stats::GAUGE);
+    ("State"),
+    (),
+    ray::stats::GAUGE);
 DEFINE_stats(scheduler_unscheduleable_tasks,
              "Number of pending tasks (not scheduleable tasks) broken per reason "
              "{Infeasible, WaitingForResources, "
              "WaitingForPlasmaMemory, WaitingForRemoteResources, WaitingForWorkers}.",
-             ("Reason"), (), ray::stats::GAUGE);
+             ("Reason"),
+             (),
+             ray::stats::GAUGE);
 DEFINE_stats(scheduler_failed_worker_startup_total,
              "Number of tasks that fail to be scheduled because workers were not "
             "available. Labels are broken per reason {JobConfigMissing, "
             "RegistrationTimedOut, RateLimited}",
-             ("Reason"), (), ray::stats::GAUGE);
+             ("Reason"),
+             (),
+             ray::stats::GAUGE);

 /// Local Object Manager
 DEFINE_stats(
     spill_manager_objects,
     "Number of local objects broken per state {Pinned, PendingRestore, PendingSpill}.",
-    ("State"), (), ray::stats::GAUGE);
+    ("State"),
+    (),
+    ray::stats::GAUGE);
 DEFINE_stats(spill_manager_objects_bytes,
              "Byte size of local objects broken per state {Pinned, PendingSpill}.",
-             ("State"), (), ray::stats::GAUGE);
-DEFINE_stats(spill_manager_request_total, "Number of {spill, restore} requests.",
-             ("Type"), (), ray::stats::GAUGE);
+             ("State"),
+             (),
+             ray::stats::GAUGE);
+DEFINE_stats(spill_manager_request_total,
+             "Number of {spill, restore} requests.",
+             ("Type"),
+             (),
+             ray::stats::GAUGE);
 DEFINE_stats(spill_manager_throughput_mb,
-             "The throughput of {spill, restore} requests in MB.", ("Type"), (),
+             "The throughput of {spill, restore} requests in MB.",
+             ("Type"),
+             (),
              ray::stats::GAUGE);

 /// GCS Resource Manager
 DEFINE_stats(gcs_new_resource_creation_latency_ms,
-             "Time to persist newly created resources to Redis.", (),
-             ({0.1, 1, 10, 100, 1000, 10000}, ), ray::stats::HISTOGRAM);
+             "Time to persist newly created resources to Redis.",
+             (),
+             ({0.1, 1, 10, 100, 1000, 10000}, ),
+             ray::stats::HISTOGRAM);

 /// Placement Group
 // The end to end placement group creation latency.
@@ -121,22 +172,30 @@ DEFINE_stats(gcs_new_resource_creation_latency_ms,
 // <-> Placement group creation succeeds (meaning all resources
 // are committed to nodes and available).
 DEFINE_stats(gcs_placement_group_creation_latency_ms,
-             "end to end latency of placement group creation", (),
-             ({0.1, 1, 10, 100, 1000, 10000}, ), ray::stats::HISTOGRAM);
+             "end to end latency of placement group creation",
+             (),
+             ({0.1, 1, 10, 100, 1000, 10000}, ),
+             ray::stats::HISTOGRAM);
 // The time from placement group scheduling has started
 // <-> Placement group creation succeeds.
 DEFINE_stats(gcs_placement_group_scheduling_latency_ms,
-             "scheduling latency of placement groups", (),
-             ({0.1, 1, 10, 100, 1000, 10000}, ), ray::stats::HISTOGRAM);
+             "scheduling latency of placement groups",
+             (),
+             ({0.1, 1, 10, 100, 1000, 10000}, ),
+             ray::stats::HISTOGRAM);
 DEFINE_stats(gcs_placement_group_count,
              "Number of placement groups broken down by state in {Registered, Pending, "
              "Infeasible}",
-             ("State"), (), ray::stats::GAUGE);
+             ("State"),
+             (),
+             ray::stats::GAUGE);

 /// GCS Actor Manager
 DEFINE_stats(gcs_actors_count,
              "Number of actors per state {Created, Destroyed, Unresolved, Pending}",
-             ("State"), (), ray::stats::GAUGE);
+             ("State"),
+             (),
+             ray::stats::GAUGE);

 }  // namespace stats
 }  // namespace ray
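metric_defs.cc above holds the DEFINE_stats definitions; metric_defs.h below carries
the matching DECLARE_stats so other translation units can reference them. The pairing
for a hypothetical stat would be:

    // metric_defs.cc
    DEFINE_stats(demo_events, "Number of demo events.", ("Type"), (), ray::stats::COUNT);

    // metric_defs.h
    DECLARE_stats(demo_events);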
diff --git a/src/ray/stats/metric_defs.h b/src/ray/stats/metric_defs.h
index 24999d46e..6f6d45411 100644
--- a/src/ray/stats/metric_defs.h
+++ b/src/ray/stats/metric_defs.h
@@ -99,7 +99,8 @@ DECLARE_stats(gcs_actors_count);
 ///
 /// RPC
 static Histogram GcsLatency("gcs_latency",
-                            "The latency of a GCS (by default Redis) operation.", "us",
+                            "The latency of a GCS (by default Redis) operation.",
+                            "us",
                             {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000},
                             {CustomKey});

@@ -109,25 +110,30 @@ static Histogram GcsLatency("gcs_latency",

 /// Raylet Resource Manager
 static Gauge LocalAvailableResource("local_available_resource",
-                                    "The available resources on this node.", "",
+                                    "The available resources on this node.",
+                                    "",
                                     {ResourceNameKey});

 static Gauge LocalTotalResource("local_total_resource",
-                                "The total resources on this node.", "",
+                                "The total resources on this node.",
+                                "",
                                 {ResourceNameKey});

 /// Object Manager.
 static Gauge ObjectStoreAvailableMemory(
     "object_store_available_memory",
-    "Amount of memory currently available in the object store.", "bytes");
+    "Amount of memory currently available in the object store.",
+    "bytes");

 static Gauge ObjectStoreUsedMemory(
     "object_store_used_memory",
-    "Amount of memory currently occupied in the object store.", "bytes");
+    "Amount of memory currently occupied in the object store.",
+    "bytes");

 static Gauge ObjectStoreFallbackMemory(
     "object_store_fallback_memory",
-    "Amount of memory in fallback allocations in the filesystem.", "bytes");
+    "Amount of memory in fallback allocations in the filesystem.",
+    "bytes");

 static Gauge ObjectStoreLocalObjects("object_store_num_local_objects",
                                      "Number of objects currently in the object store.",
@@ -175,16 +181,19 @@ static Histogram HeartbeatReportMs(
     "Heartbeat report time in raylet. If this value is high, that means there's a high "
     "system load. It is possible that this node will be killed because of missing "
     "heartbeats.",
-    "ms", {100, 200, 400, 800, 1600, 3200, 6400, 15000, 30000});
+    "ms",
+    {100, 200, 400, 800, 1600, 3200, 6400, 15000, 30000});

 /// Worker Pool
 static Histogram ProcessStartupTimeMs("process_startup_time_ms",
-                                      "Time to start up a worker process.", "ms",
+                                      "Time to start up a worker process.",
+                                      "ms",
                                       {1, 10, 100, 1000, 10000});

 static Sum NumWorkersStarted(
     "internal_num_processes_started",
-    "The total number of worker processes the worker pool has created.", "processes");
+    "The total number of worker processes the worker pool has created.",
+    "processes");

 static Sum NumSpilledTasks("internal_num_spilled_tasks",
                            "The cumulative number of lease requeusts that this raylet "
@@ -193,7 +202,8 @@ static Sum NumSpilledTasks("internal_num_spilled_tasks",

 static Gauge NumInfeasibleSchedulingClasses(
     "internal_num_infeasible_scheduling_classes",
-    "The number of unique scheduling classes that are infeasible.", "tasks");
+    "The number of unique scheduling classes that are infeasible.",
+    "tasks");

 ///
 /// GCS Server Metrics
 ///
@@ -208,21 +218,27 @@ static Count UnintentionalWorkerFailures(

 /// Nodes
 static Count NodeFailureTotal(
-    "node_failure_total", "Number of node failures that have happened in the cluster.",
+    "node_failure_total",
+    "Number of node failures that have happened in the cluster.",
     "");

 /// Resources
 static Histogram OutboundHeartbeatSizeKB("outbound_heartbeat_size_kb",
-                                         "Outbound heartbeat payload size", "kb",
+                                         "Outbound heartbeat payload size",
+                                         "kb",
                                          {10, 50, 100, 1000, 10000, 100000});

 static Histogram GcsUpdateResourceUsageTime(
-    "gcs_update_resource_usage_time", "The average RTT of a UpdateResourceUsage RPC.",
-    "ms", {1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000}, {CustomKey});
+    "gcs_update_resource_usage_time",
+    "The average RTT of a UpdateResourceUsage RPC.",
+    "ms",
+    {1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000},
+    {CustomKey});

 /// Testing
 static Gauge LiveActors("live_actors", "Number of live actors.", "actors");
-static Gauge RestartingActors("restarting_actors", "Number of restarting actors.",
+static Gauge RestartingActors("restarting_actors",
+                              "Number of restarting actors.",
                               "actors");

 }  // namespace stats
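The static metrics above follow the constructor layout from metric.h: once the
argument list exceeds the column limit, each constructor argument gets its own line.
A hypothetical new gauge in that style (name and description made up for
illustration; CustomKey is the tag key used elsewhere in this header):

    static Gauge DemoQueueDepth("demo_queue_depth",
                                "Number of queued demo requests.",
                                "requests",
                                {CustomKey});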
diff --git a/src/ray/stats/metric_exporter.cc b/src/ray/stats/metric_exporter.cc
index 9ebe49e42..f39958f8a 100644
--- a/src/ray/stats/metric_exporter.cc
+++ b/src/ray/stats/metric_exporter.cc
@@ -24,7 +24,8 @@ void MetricPointExporter::ExportToPoints(
     const opencensus::stats::ViewData::DataMap &view_data,
     const opencensus::stats::MeasureDescriptor &measure_descriptor,
-    std::vector<std::string> &keys, std::vector<MetricPoint> &points) {
+    std::vector<std::string> &keys,
+    std::vector<MetricPoint> &points) {
   // Return if no raw data found in view map.
   if (view_data.size() == 0) {
     return;
   }
@@ -55,12 +56,12 @@ void MetricPointExporter::ExportToPoints(
     }
   }
   hist_mean /= view_data.size();
-  MetricPoint mean_point = {metric_name + ".mean", current_sys_time_ms(), hist_mean, tags,
-                            measure_descriptor};
-  MetricPoint max_point = {metric_name + ".max", current_sys_time_ms(), hist_max, tags,
-                           measure_descriptor};
-  MetricPoint min_point = {metric_name + ".min", current_sys_time_ms(), hist_min, tags,
-                           measure_descriptor};
+  MetricPoint mean_point = {
+      metric_name + ".mean", current_sys_time_ms(), hist_mean, tags, measure_descriptor};
+  MetricPoint max_point = {
+      metric_name + ".max", current_sys_time_ms(), hist_max, tags, measure_descriptor};
+  MetricPoint min_point = {
+      metric_name + ".min", current_sys_time_ms(), hist_min, tags, measure_descriptor};
   points.push_back(std::move(mean_point));
   points.push_back(std::move(max_point));
   points.push_back(std::move(min_point));
@@ -94,8 +95,8 @@ void MetricPointExporter::ExportViewData(
       ExportToPoints(view_data.int_data(), measure_descriptor, keys, points);
       break;
     case opencensus::stats::ViewData::Type::kDistribution:
-      ExportToPoints(view_data.distribution_data(),
-                     measure_descriptor, keys, points);
+      ExportToPoints(
+          view_data.distribution_data(), measure_descriptor, keys, points);
       break;
     default:
       RAY_LOG(FATAL) << "Unknown view data type.";
diff --git a/src/ray/stats/metric_exporter.h b/src/ray/stats/metric_exporter.h
index 8644b5514..cb7cc2bb5 100644
--- a/src/ray/stats/metric_exporter.h
+++ b/src/ray/stats/metric_exporter.h
@@ -14,6 +14,7 @@
 #pragma once

 #include
+
 #include "absl/memory/memory.h"
 #include "opencensus/stats/stats.h"
 #include "opencensus/tags/tag_key.h"
@@ -62,7 +63,8 @@ class MetricPointExporter final : public opencensus::stats::StatsExporter::Handl
   /// \param points, memory metric vector instance
   void ExportToPoints(const opencensus::stats::ViewData::DataMap &view_data,
                       const opencensus::stats::MeasureDescriptor &measure_descriptor,
-                      std::vector<std::string> &keys, std::vector<MetricPoint> &points) {
+                      std::vector<std::string> &keys,
+                      std::vector<MetricPoint> &points) {
     const auto &metric_name = measure_descriptor.name();
     for (const auto &row : view_data) {
       std::unordered_map<std::string, std::string> tags;
@@ -70,8 +72,11 @@ class MetricPointExporter final : public opencensus::stats::StatsExporter::Handl
         tags[keys[i]] = row.first[i];
       }
       // Current timestamp is used for point not view data time.
-      MetricPoint point{metric_name, current_sys_time_ms(),
-                        static_cast<double>(row.second), tags, measure_descriptor};
+      MetricPoint point{metric_name,
+                        current_sys_time_ms(),
+                        static_cast<double>(row.second),
+                        tags,
+                        measure_descriptor};
       points.push_back(std::move(point));
       if (points.size() >= report_batch_size_) {
         metric_exporter_client_->ReportMetrics(points);
@@ -89,12 +94,14 @@ class MetricPointExporter final : public opencensus::stats::StatsExporter::Handl

 class OpenCensusProtoExporter final : public opencensus::stats::StatsExporter::Handler {
  public:
-  OpenCensusProtoExporter(const int port, instrumented_io_context &io_service,
+  OpenCensusProtoExporter(const int port,
+                          instrumented_io_context &io_service,
                           const std::string address);

   ~OpenCensusProtoExporter() = default;

-  static void Register(const int port, instrumented_io_context &io_service,
+  static void Register(const int port,
+                       instrumented_io_context &io_service,
                        const std::string address) {
     opencensus::stats::StatsExporter::RegisterPushHandler(
         absl::make_unique<OpenCensusProtoExporter>(port, io_service, address));
diff --git a/src/ray/stats/metric_exporter_client_test.cc b/src/ray/stats/metric_exporter_client_test.cc
index 55cea6f9a..fb04acc5c 100644
--- a/src/ray/stats/metric_exporter_client_test.cc
+++ b/src/ray/stats/metric_exporter_client_test.cc
@@ -166,8 +166,8 @@ TEST_F(MetricExporterClientTest, exporter_client_caculation_test) {
   for (int i = 0; i < 50; i++) {
     hist_vector.push_back((double)(i * 10.0));
   }
-  static stats::Histogram random_hist("ray.random.hist", "", "", hist_vector,
-                                      {tag1, tag2});
+  static stats::Histogram random_hist(
+      "ray.random.hist", "", "", hist_vector, {tag1, tag2});
   for (size_t i = 0; i < 500; ++i) {
     random_counter.Record(i, {{tag1, std::to_string(i)}, {tag2, std::to_string(i * 2)}});
     random_gauge.Record(i, {{tag1, std::to_string(i)}, {tag2, std::to_string(i * 2)}});
diff --git a/src/ray/stats/stats.h b/src/ray/stats/stats.h
index d6a135c9c..03ec6ab5f 100644
--- a/src/ray/stats/stats.h
+++ b/src/ray/stats/stats.h
@@ -52,7 +52,8 @@ static absl::Mutex stats_mutex;
 /// \param global_tags[in] Tags that will be appended to all metrics in this process.
 /// \param metrics_agent_port[in] The port to export metrics at each node.
 /// \param exporter_to_use[in] The exporter client you will use for this process' metrics.
-static inline void Init(const TagsType &global_tags, const int metrics_agent_port,
+static inline void Init(const TagsType &global_tags,
+                        const int metrics_agent_port,
                         std::shared_ptr<MetricExporterClient> exporter_to_use = nullptr,
                         int64_t metrics_report_batch_size =
                             RayConfig::instance().metrics_report_batch_size()) {
@@ -94,8 +95,8 @@ static inline void Init(const TagsType &global_tags, const int metrics_agent_por
                                           static_cast(500))));

   MetricPointExporter::Register(exporter, metrics_report_batch_size);
-  OpenCensusProtoExporter::Register(metrics_agent_port, (*metrics_io_service),
-                                    "127.0.0.1");
+  OpenCensusProtoExporter::Register(
+      metrics_agent_port, (*metrics_io_service), "127.0.0.1");
   opencensus::stats::StatsExporter::SetInterval(
       StatsConfig::instance().GetReportInterval());
   opencensus::stats::DeltaProducer::Get()->SetHarvestInterval(
diff --git a/src/ray/stats/stats_test.cc b/src/ray/stats/stats_test.cc
index 183c64624..f60083945 100644
--- a/src/ray/stats/stats_test.cc
+++ b/src/ray/stats/stats_test.cc
@@ -24,13 +24,20 @@
 #include "gtest/gtest.h"
 #include "ray/stats/metric_defs.h"

-DEFINE_stats(test_hist, "TestStats", ("method", "method2"), (1.0, 2.0, 3.0, 4.0),
+DEFINE_stats(test_hist,
+             "TestStats",
+             ("method", "method2"),
+             (1.0, 2.0, 3.0, 4.0),
              ray::stats::HISTOGRAM);
-DEFINE_stats(test_2, "TestStats", ("method", "method2"), (1.0), ray::stats::COUNT,
+DEFINE_stats(test_2,
+             "TestStats",
+             ("method", "method2"),
+             (1.0),
+             ray::stats::COUNT,
              ray::stats::SUM);
 DEFINE_stats(test, "TestStats", ("method"), (1.0), ray::stats::COUNT, ray::stats::SUM);
-DEFINE_stats(test_declare, "TestStats2", ("tag1"), (1.0), ray::stats::COUNT,
-             ray::stats::SUM);
+DEFINE_stats(
+    test_declare, "TestStats2", ("tag1"), (1.0), ray::stats::COUNT, ray::stats::SUM);
 DECLARE_stats(test_declare);

 namespace ray {
@@ -105,7 +112,8 @@ TEST_F(StatsTest, InitializationTest) {
     std::shared_ptr<stats::StdoutExporterClient> exporter(
         new stats::StdoutExporterClient());
     ray::stats::Init({{stats::LanguageKey, test_tag_value_that_shouldnt_be_applied}},
-                     MetricsAgentPort, exporter);
+                     MetricsAgentPort,
+                     exporter);
   }

   auto &first_tag = ray::stats::StatsConfig::instance().GetGlobalTags()[0];
@@ -138,15 +146,21 @@ TEST(Metric, MultiThreadMetricRegisterViewTest) {
     threads.emplace_back([tag1, tag2, index]() {
       for (int i = 0; i < 100; i++) {
         stats::Count random_counter(
-            "ray.random.counter" + std::to_string(index) + std::to_string(i), "", "",
+            "ray.random.counter" + std::to_string(index) + std::to_string(i),
+            "",
+            "",
             {tag1, tag2});
         random_counter.Record(i);
         stats::Gauge random_gauge(
-            "ray.random.gauge" + std::to_string(index) + std::to_string(i), "", "",
+            "ray.random.gauge" + std::to_string(index) + std::to_string(i),
+            "",
+            "",
             {tag1, tag2});
         random_gauge.Record(i);
         stats::Sum random_sum(
-            "ray.random.sum" + std::to_string(index) + std::to_string(i), "", "",
+            "ray.random.sum" + std::to_string(index) + std::to_string(i),
+            "",
+            "",
             {tag1, tag2});
         random_sum.Record(i);
       }
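For reference, a minimal initialization call against the reflowed Init signature,
using the exporter type from the test above (port value and tag value are
illustrative):

    std::shared_ptr<stats::StdoutExporterClient> exporter(
        new stats::StdoutExporterClient());
    ray::stats::Init({{stats::LanguageKey, "CPP"}},
                     /*metrics_agent_port=*/10054,
                     exporter);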
diff --git a/src/ray/util/event.cc b/src/ray/util/event.cc
index 2d167694d..7523f2957 100644
--- a/src/ray/util/event.cc
+++ b/src/ray/util/event.cc
@@ -24,8 +24,10 @@ namespace ray {
 /// LogEventReporter
 ///
 LogEventReporter::LogEventReporter(rpc::Event_SourceType source_type,
-                                   const std::string &log_dir, bool force_flush,
-                                   int rotate_max_file_size, int rotate_max_file_num)
+                                   const std::string &log_dir,
+                                   bool force_flush,
+                                   int rotate_max_file_size,
+                                   int rotate_max_file_num)
     : log_dir_(log_dir),
       force_flush_(force_flush),
       rotate_max_file_size_(rotate_max_file_size),
@@ -51,9 +53,10 @@ LogEventReporter::LogEventReporter(rpc::Event_SourceType source_type,
   // for example event_GCS.0.log, event_GCS.1.log, event_GCS.2.log ...
   // We alow to rotate for {rotate_max_file_num_} times.
   if (log_sink_ == nullptr) {
-    log_sink_ =
-        spdlog::rotating_logger_mt(log_sink_key, log_dir_ + file_name_,
-                                   1048576 * rotate_max_file_size_, rotate_max_file_num_);
+    log_sink_ = spdlog::rotating_logger_mt(log_sink_key,
+                                           log_dir_ + file_name_,
+                                           1048576 * rotate_max_file_size_,
+                                           rotate_max_file_num_);
   }
   log_sink_->set_pattern("%v");
 }
@@ -84,8 +87,8 @@ std::string LogEventReporter::EventToString(const rpc::Event &event,
   absl::Time absl_time = absl::FromTimeT(epoch_time_as_time_t);
   std::stringstream time_stamp_buffer;
-  time_stamp_buffer << absl::FormatTime("%Y-%m-%d %H:%M:%S.", absl_time,
-                                        absl::LocalTimeZone())
+  time_stamp_buffer << absl::FormatTime(
+                           "%Y-%m-%d %H:%M:%S.", absl_time, absl::LocalTimeZone())
                     << std::setw(6) << std::setfill('0') << time_stamp % 1000000;

   j["time_stamp"] = time_stamp_buffer.str();
@@ -229,14 +232,16 @@ static void SetEventLevel(const std::string &event_level) {
   RAY_LOG(INFO) << "Set ray event level to " << level;
 }

-void RayEvent::ReportEvent(const std::string &severity, const std::string &label,
-                           const std::string &message, const char *file_name,
+void RayEvent::ReportEvent(const std::string &severity,
+                           const std::string &label,
+                           const std::string &message,
+                           const char *file_name,
                            int line_number) {
   rpc::Event_Severity severity_ele =
       rpc::Event_Severity::Event_Severity_Event_Severity_INT_MIN_SENTINEL_DO_NOT_USE_;
   RAY_CHECK(rpc::Event_Severity_Parse(severity, &severity_ele));
-  RayEvent(severity_ele, EventLevelToLogLevel(severity_ele), label, file_name,
-           line_number)
+  RayEvent(
+      severity_ele, EventLevelToLogLevel(severity_ele), label, file_name, line_number)
       << message;
 }

@@ -316,7 +321,8 @@ static absl::once_flag init_once_;

 void RayEventInit(rpc::Event_SourceType source_type,
                   const absl::flat_hash_map<std::string, std::string> &custom_fields,
-                  const std::string &log_dir, const std::string &event_level) {
+                  const std::string &log_dir,
+                  const std::string &event_level) {
   absl::call_once(init_once_, [&source_type, &custom_fields, &log_dir, &event_level]() {
     RayEventContext::Instance().SetEventContext(source_type, custom_fields);
     auto event_dir = boost::filesystem::path(log_dir) / boost::filesystem::path("events");
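The sink wiring above is plain spdlog; a standalone sketch with the same defaults
(100 MB per file, up to 20 rotated files; logger key and path are made up for
illustration):

    #include "spdlog/sinks/rotating_file_sink.h"

    auto log_sink = spdlog::rotating_logger_mt("event_GCS",
                                               "/tmp/ray/events/event_GCS.log",
                                               1048576 * 100,
                                               20);
    log_sink->set_pattern("%v");  // raw message only; EventToString pre-renders the line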
diff --git a/src/ray/util/event.h b/src/ray/util/event.h
index ac9842ea8..cc88c9eaa 100644
--- a/src/ray/util/event.h
+++ b/src/ray/util/event.h
@@ -46,7 +46,9 @@ namespace ray {
   ::ray::RayEvent(::ray::rpc::Event_Severity::Event_Severity_##event_type,            \
                   ray::RayEvent::EventLevelToLogLevel(                                \
                       ::ray::rpc::Event_Severity::Event_Severity_##event_type),       \
-                  label, __FILE__, __LINE__)
+                  label,                                                              \
+                  __FILE__,                                                           \
+                  __LINE__)

 // interface of event reporter
 class BaseEventReporter {
@@ -62,8 +64,10 @@ class BaseEventReporter {
 // responsible for writing event to specific file
 class LogEventReporter : public BaseEventReporter {
  public:
-  LogEventReporter(rpc::Event_SourceType source_type, const std::string &log_dir,
-                   bool force_flush = true, int rotate_max_file_size = 100,
+  LogEventReporter(rpc::Event_SourceType source_type,
+                   const std::string &log_dir,
+                   bool force_flush = true,
+                   int rotate_max_file_size = 100,
                    int rotate_max_file_num = 20);

   virtual ~LogEventReporter();
@@ -198,8 +202,11 @@ class RayEvent {
  public:
   // We require file_name to be a string which has static storage before RayEvent
   // deconstructed. Otherwise we might have memory issues.
-  RayEvent(rpc::Event_Severity severity, RayLogLevel log_severity,
-           const std::string &label, const char *file_name, int line_number)
+  RayEvent(rpc::Event_Severity severity,
+           RayLogLevel log_severity,
+           const std::string &label,
+           const char *file_name,
+           int line_number)
       : severity_(severity),
         log_severity_(log_severity),
         label_(label),
@@ -221,8 +228,10 @@ class RayEvent {
     return *this;
   }

-  static void ReportEvent(const std::string &severity, const std::string &label,
-                          const std::string &message, const char *file_name,
+  static void ReportEvent(const std::string &severity,
+                          const std::string &label,
+                          const std::string &message,
+                          const char *file_name,
                           int line_number);

   /// Return whether or not the event level is enabled in current setting.
@@ -273,6 +282,7 @@ class RayEvent {
 /// \return void.
 void RayEventInit(rpc::Event_SourceType source_type,
                   const absl::flat_hash_map<std::string, std::string> &custom_fields,
-                  const std::string &log_dir, const std::string &event_level = "warning");
+                  const std::string &log_dir,
+                  const std::string &event_level = "warning");

 }  // namespace ray
diff --git a/src/ray/util/event_test.cc b/src/ray/util/event_test.cc
index 2d6d78f05..031138c9e 100644
--- a/src/ray/util/event_test.cc
+++ b/src/ray/util/event_test.cc
@@ -43,9 +43,14 @@ class TestEventReporter : public BaseEventReporter {
 std::vector<rpc::Event> TestEventReporter::event_list = std::vector<rpc::Event>();

-void CheckEventDetail(rpc::Event &event, std::string job_id, std::string node_id,
-                      std::string task_id, std::string source_type, std::string severity,
-                      std::string label, std::string message) {
+void CheckEventDetail(rpc::Event &event,
+                      std::string job_id,
+                      std::string node_id,
+                      std::string task_id,
+                      std::string source_type,
+                      std::string severity,
+                      std::string label,
+                      std::string message) {
   int custom_key_num = 0;
   auto mp = (*event.mutable_custom_fields());

@@ -118,7 +123,8 @@ rpc::Event GetEventFromString(std::string seq, json *custom_fields) {
   return event;
 }

-void ParallelRunning(int nthreads, int loop_times,
+void ParallelRunning(int nthreads,
+                     int loop_times,
                      std::function<void()> event_context_init,
                      std::function<void(int)> loop_function) {
   if (nthreads > 1) {
@@ -132,7 +138,8 @@ void ParallelRunning(int nthreads, int loop_times,
             }
           },
           t * loop_times / nthreads,
-          (t + 1) == nthreads ? loop_times : (t + 1) * loop_times / nthreads, t));
+          (t + 1) == nthreads ? loop_times : (t + 1) * loop_times / nthreads,
+          t));
     }
     std::for_each(threads.begin(), threads.end(), [](std::thread &x) { x.join(); });
   } else {
@@ -143,7 +150,8 @@ void ParallelRunning(int nthreads, int loop_times,
   }
 }

-void ReadContentFromFile(std::vector<std::string> &vc, std::string log_file,
+void ReadContentFromFile(std::vector<std::string> &vc,
+                         std::string log_file,
                          std::string filter = "") {
   std::string line;
   std::ifstream read_file;
@@ -213,13 +221,25 @@ TEST_F(EventTest, TestBasic) {

   EXPECT_EQ(result.size(), 4);

-  CheckEventDetail(result[0], "", "", "", "COMMON", "WARNING", "label 0",
-                   "send message 0");
+  CheckEventDetail(
+      result[0], "", "", "", "COMMON", "WARNING", "label 0", "send message 0");

-  CheckEventDetail(result[1], "job 1", "node 1", "task 1", "CORE_WORKER", "INFO",
-                   "label 1", "send message 1");
+  CheckEventDetail(result[1],
+                   "job 1",
+                   "node 1",
+                   "task 1",
+                   "CORE_WORKER",
+                   "INFO",
+                   "label 1",
+                   "send message 1");

-  CheckEventDetail(result[2], "job 2", "node 2", "", "RAYLET", "ERROR", "label 2",
+  CheckEventDetail(result[2],
+                   "job 2",
+                   "node 2",
+                   "",
+                   "RAYLET",
+                   "ERROR",
+                   "label 2",
                    "send message 2 send message again");

   CheckEventDetail(result[3], "", "", "", "GCS", "FATAL", "", "");
@@ -247,11 +267,23 @@ TEST_F(EventTest, TestUpdateCustomFields) {

   EXPECT_EQ(result.size(), 2);

-  CheckEventDetail(result[0], "job 1", "node 1", "", "CORE_WORKER", "INFO", "label 1",
+  CheckEventDetail(result[0],
+                   "job 1",
+                   "node 1",
+                   "",
+                   "CORE_WORKER",
+                   "INFO",
+                   "label 1",
                    "send message 1");

-  CheckEventDetail(result[1], "job 2", "node 1", "task 2", "CORE_WORKER", "ERROR",
-                   "label 2", "send message 2 send message again");
+  CheckEventDetail(result[1],
+                   "job 2",
+                   "node 1",
+                   "task 2",
+                   "CORE_WORKER",
+                   "ERROR",
+                   "label 2",
+                   "send message 2 send message again");
 }

 TEST_F(EventTest, TestLogOneThread) {
@@ -276,7 +308,12 @@ TEST_F(EventTest, TestLogOneThread) {
   for (int i = 0, len = vc.size(); i < print_times; ++i) {
     json custom_fields;
     rpc::Event ele = GetEventFromString(vc[len - print_times + i], &custom_fields);
-    CheckEventDetail(ele, "job 1", "node 1", "task 1", "RAYLET", "INFO",
+    CheckEventDetail(ele,
+                     "job 1",
+                     "node 1",
+                     "task 1",
+                     "RAYLET",
+                     "INFO",
                      "label " + std::to_string(i + 1),
                      "send message " + std::to_string(i + 1));
   }
@@ -304,10 +341,10 @@ TEST_F(EventTest, TestMultiThreadContextCopy) {
   EXPECT_EQ(result.size(), 3);
   CheckEventDetail(result[0], "", "", "", "COMMON", "INFO", "label 0", "send message 0");
-  CheckEventDetail(result[1], "job 1", "node 1", "task 1", "GCS", "INFO", "label 2",
-                   "send message 2");
-  CheckEventDetail(result[2], "job 1", "node 1", "task 1", "GCS", "INFO", "label 1",
-                   "send message 1");
+  CheckEventDetail(
+      result[1], "job 1", "node 1", "task 1", "GCS", "INFO", "label 2", "send message 2");
+  CheckEventDetail(
+      result[2], "job 1", "node 1", "task 1", "GCS", "INFO", "label 1", "send message 1");

   ray::RayEventContext::Instance().ResetEventContext();
   TestEventReporter::event_list.clear();
@@ -324,8 +361,8 @@ TEST_F(EventTest, TestMultiThreadContextCopy) {
   RAY_EVENT(INFO, "label 3") << "send message 3";

   EXPECT_EQ(result.size(), 2);
-  CheckEventDetail(result[0], "job 1", "", "", "RAYLET", "INFO", "label 2",
-                   "send message 2");
+  CheckEventDetail(
+      result[0], "job 1", "", "", "RAYLET", "INFO", "label 2", "send message 2");
   CheckEventDetail(result[1], "", "", "", "COMMON", "INFO", "label 3", "send message 3");
 }

@@ -336,7 +373,8 @@ TEST_F(EventTest, TestLogMultiThread) {
   int print_times = 1000;
   ParallelRunning(
-      nthreads, print_times,
+      nthreads,
+      print_times,
       []() {
         RayEventContext::Instance().SetEventContext(
             rpc::Event_SourceType::Event_SourceType_GCS,
@@ -418,8 +456,8 @@ TEST_F(EventTest, TestWithField) {
   json custom_fields;
   rpc::Event ele = GetEventFromString(vc[0], &custom_fields);

-  CheckEventDetail(ele, "job 1", "node 1", "task 1", "RAYLET", "INFO", "label 1",
-                   "send message 1");
+  CheckEventDetail(
+      ele, "job 1", "node 1", "task 1", "RAYLET", "INFO", "label 1", "send message 1");
   auto string_value = custom_fields["string"].get<std::string>();
   EXPECT_EQ(string_value, "test string");
   auto int_value = custom_fields["int"].get<int>();
@@ -449,8 +487,14 @@ TEST_F(EventTest, TestRayCheckAbort) {

   json out_custom_fields;
   rpc::Event ele_1 = GetEventFromString(vc.back(), &out_custom_fields);
-  CheckEventDetail(ele_1, "job 1", "node 1", "task 1", "RAYLET", "FATAL",
-                   EL_RAY_FATAL_CHECK_FAILED, "NULL");
+  CheckEventDetail(ele_1,
+                   "job 1",
+                   "node 1",
+                   "task 1",
+                   "RAYLET",
+                   "FATAL",
+                   EL_RAY_FATAL_CHECK_FAILED,
+                   "NULL");
   EXPECT_THAT(ele_1.message(),
               testing::HasSubstr("Check failed: 1 < 0 incorrect test case"));
   EXPECT_THAT(ele_1.message(), testing::HasSubstr("*** StackTrace Information ***"));
@@ -472,8 +516,8 @@ TEST_F(EventTest, TestRayEventInit) {

   json out_custom_fields;
   rpc::Event ele_1 = GetEventFromString(vc.back(), &out_custom_fields);
-  CheckEventDetail(ele_1, "job 1", "node 1", "task 1", "RAYLET", "FATAL", "label",
-                   "NULL");
+  CheckEventDetail(
+      ele_1, "job 1", "node 1", "task 1", "RAYLET", "FATAL", "label", "NULL");
 }

 TEST_F(EventTest, TestLogLevel) {
@@ -492,8 +536,8 @@ TEST_F(EventTest, TestLogLevel) {
   std::vector<rpc::Event> &result = TestEventReporter::event_list;
   EXPECT_EQ(result.size(), 4);
   CheckEventDetail(result[0], "", "", "", "CORE_WORKER", "INFO", "label", "test info");
-  CheckEventDetail(result[1], "", "", "", "CORE_WORKER", "WARNING", "label",
-                   "test warning");
+  CheckEventDetail(
+      result[1], "", "", "", "CORE_WORKER", "WARNING", "label", "test warning");
   CheckEventDetail(result[2], "", "", "", "CORE_WORKER", "ERROR", "label", "test error");
   CheckEventDetail(result[3], "", "", "", "CORE_WORKER", "FATAL", "label", "test fatal");
   result.clear();
@@ -506,8 +550,8 @@ TEST_F(EventTest, TestLogLevel) {
   RAY_EVENT(FATAL, "label") << "test fatal";

   EXPECT_EQ(result.size(), 3);
-  CheckEventDetail(result[0], "", "", "", "CORE_WORKER", "WARNING", "label",
-                   "test warning");
+  CheckEventDetail(
+      result[0], "", "", "", "CORE_WORKER", "WARNING", "label", "test warning");
   CheckEventDetail(result[1], "", "", "", "CORE_WORKER", "ERROR", "label", "test error");
   CheckEventDetail(result[2], "", "", "", "CORE_WORKER", "FATAL", "label", "test fatal");
   result.clear();
@@ -551,8 +595,8 @@ TEST_F(EventTest, TestLogEvent) {
   RAY_EVENT(FATAL, "label") << "test fatal";

   std::vector<std::string> vc;
-  ReadContentFromFile(vc, log_dir + "/event_test_" + std::to_string(getpid()) + ".log",
-                      "[ Event ");
+  ReadContentFromFile(
+      vc, log_dir + "/event_test_" + std::to_string(getpid()) + ".log", "[ Event ");
   EXPECT_EQ((int)vc.size(), 2);
   // Check ERROR event
   EXPECT_THAT(vc[0], testing::HasSubstr(" E "));
@@ -576,8 +620,8 @@ TEST_F(EventTest, TestLogEvent) {
   RAY_EVENT(FATAL, "label") << "test fatal 2";

   vc.clear();
-  ReadContentFromFile(vc, log_dir + "/event_test_" + std::to_string(getpid()) + ".log",
-                      "[ Event ");
+  ReadContentFromFile(
+      vc, log_dir + "/event_test_" + std::to_string(getpid()) + ".log", "[ Event ");
   EXPECT_EQ((int)vc.size(), 4);
   // Check INFO event
   EXPECT_THAT(vc[0], testing::HasSubstr(" I "));
diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc
index 3953e570e..7a3128c8c 100644
--- a/src/ray/util/logging.cc
+++ b/src/ray/util/logging.cc
@@ -104,7 +104,9 @@ class DefaultStdErrLogger final {

 class SpdLogMessage final {
  public:
-  explicit SpdLogMessage(const char *file, int line, int loglevel,
+  explicit SpdLogMessage(const char *file,
+                         int line,
+                         int loglevel,
                          std::shared_ptr<std::ostringstream> expose_osstream)
       : loglevel_(loglevel), expose_osstream_(expose_osstream) {
     stream() << ConstBasename(file) << ":" << line << ": ";
@@ -123,8 +125,8 @@ class SpdLogMessage final {
       *expose_osstream_ << "\n*** StackTrace Information ***\n" << ray::GetCallTrace();
     }
     // NOTE(lingxuan.zlx): See more fmt by visiting https://github.com/fmtlib/fmt.
-    logger->log(static_cast<spdlog::level::level_enum>(loglevel_), /*fmt*/ "{}",
-                str_.str());
+    logger->log(
+        static_cast<spdlog::level::level_enum>(loglevel_), /*fmt*/ "{}", str_.str());
     logger->flush();
   }

@@ -167,7 +169,8 @@ static int GetMappedSeverity(RayLogLevel severity) {

 std::vector RayLog::fatal_log_callbacks_;

-void RayLog::StartRayLog(const std::string &app_name, RayLogLevel severity_threshold,
+void RayLog::StartRayLog(const std::string &app_name,
+                         RayLogLevel severity_threshold,
                          const std::string &log_dir) {
   const char *var_value = std::getenv("RAY_BACKEND_LOG_LEVEL");
   if (var_value != nullptr) {
@@ -244,7 +247,8 @@ void RayLog::StartRayLog(const std::string &app_name, RayLogLevel severity_thres
     }
     auto file_sink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(
         JoinPaths(log_dir_, app_name_without_path + "_" + std::to_string(pid) + ".log"),
-        log_rotation_max_size_, log_rotation_file_num_);
+        log_rotation_max_size_,
+        log_rotation_file_num_);
     sinks.push_back(file_sink);
   } else {
     // Format pattern is 2020-08-21 17:00:00,000 I 100 1001 msg.
@@ -265,8 +269,8 @@ void RayLog::StartRayLog(const std::string &app_name, RayLogLevel severity_thres
   sinks.push_back(err_sink);

   // Set the combined logger.
-  auto logger = std::make_shared<spdlog::logger>(RayLog::GetLoggerName(), sinks.begin(),
-                                                 sinks.end());
+  auto logger = std::make_shared<spdlog::logger>(
+      RayLog::GetLoggerName(), sinks.begin(), sinks.end());
   logger->set_level(level);
   logger->set_pattern(log_format_pattern_);
   spdlog::set_level(static_cast<spdlog::level::level_enum>(severity_threshold_));
@@ -361,7 +365,8 @@ std::string RayLog::GetLoggerName() { return logger_name_; }

 void RayLog::AddFatalLogCallbacks(
     const std::vector &expose_log_callbacks) {
-  fatal_log_callbacks_.insert(fatal_log_callbacks_.end(), expose_log_callbacks.begin(),
+  fatal_log_callbacks_.insert(fatal_log_callbacks_.end(),
+                              expose_log_callbacks.begin(),
                               expose_log_callbacks.end());
 }
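StartRayLog above assembles several sinks into one logger; the equivalent standalone
spdlog pattern, assuming file_sink and err_sink are built as in the hunk above
(logger name illustrative):

    std::vector<spdlog::sink_ptr> sinks = {file_sink, err_sink};
    auto logger = std::make_shared<spdlog::logger>(
        "ray_log", sinks.begin(), sinks.end());
    spdlog::register_logger(logger);  // make it retrievable via spdlog::get()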
diff --git a/src/ray/util/logging_test.cc b/src/ray/util/logging_test.cc
index 3445803a1..c4c145fe0 100644
--- a/src/ray/util/logging_test.cc
+++ b/src/ray/util/logging_test.cc
@@ -197,8 +197,8 @@ TEST(PrintLogTest, LogTestWithInit) {

 // This test will output large amount of logs to stderr, should be disabled in travis.
 TEST(LogPerfTest, PerfTest) {
-  RayLog::StartRayLog("/fake/path/to/appdire/LogPerfTest", RayLogLevel::ERROR,
-                      ray::GetUserTempDir());
+  RayLog::StartRayLog(
+      "/fake/path/to/appdire/LogPerfTest", RayLogLevel::ERROR, ray::GetUserTempDir());
   int rounds = 10;

   int64_t start_time = current_time_ms();
diff --git a/src/ray/util/memory.cc b/src/ray/util/memory.cc
index 699b65efa..9cbb97d20 100644
--- a/src/ray/util/memory.cc
+++ b/src/ray/util/memory.cc
@@ -25,8 +25,11 @@ uint8_t *pointer_logical_and(const uint8_t *address, uintptr_t bits) {
   return reinterpret_cast(value & bits);
 }

-void parallel_memcopy(uint8_t *dst, const uint8_t *src, int64_t nbytes,
-                      uintptr_t block_size, int num_threads) {
+void parallel_memcopy(uint8_t *dst,
+                      const uint8_t *src,
+                      int64_t nbytes,
+                      uintptr_t block_size,
+                      int num_threads) {
   std::vector threadpool(num_threads);
   uint8_t *left = pointer_logical_and(src + block_size - 1, ~(block_size - 1));
   uint8_t *right = pointer_logical_and(src + nbytes, ~(block_size - 1));
@@ -47,8 +50,8 @@ void parallel_memcopy(uint8_t *dst, const uint8_t *src, int64_t nbytes,

   // Start all threads first and handle leftovers while threads run.
   for (int i = 0; i < num_threads; i++) {
-    threadpool[i] = std::thread(std::memcpy, dst + prefix + i * chunk_size,
-                                left + i * chunk_size, chunk_size);
+    threadpool[i] = std::thread(
+        std::memcpy, dst + prefix + i * chunk_size, left + i * chunk_size, chunk_size);
   }

   std::memcpy(dst, src, prefix);
diff --git a/src/ray/util/memory.h b/src/ray/util/memory.h
index d1322800b..9eebe5448 100644
--- a/src/ray/util/memory.h
+++ b/src/ray/util/memory.h
@@ -20,7 +20,10 @@ namespace ray {

 // A helper function for doing memcpy with multiple threads. This is required
 // to saturate the memory bandwidth of modern cpus.
-void parallel_memcopy(uint8_t *dst, const uint8_t *src, int64_t nbytes,
-                      uintptr_t block_size, int num_threads);
+void parallel_memcopy(uint8_t *dst,
+                      const uint8_t *src,
+                      int64_t nbytes,
+                      uintptr_t block_size,
+                      int num_threads);

 }  // namespace ray
diff --git a/src/ray/util/process.cc b/src/ray/util/process.cc
index 69be1158c..8d634dff9 100644
--- a/src/ray/util/process.cc
+++ b/src/ray/util/process.cc
@@ -98,7 +98,9 @@ class ProcessFD {
   pid_t GetId() const;

   // Fork + exec combo. Returns -1 for the PID on failure.
-  static ProcessFD spawnvpe(const char *argv[], std::error_code &ec, bool decouple,
+  static ProcessFD spawnvpe(const char *argv[],
+                            std::error_code &ec,
+                            bool decouple,
                             const ProcessEnvironment &env) {
     ec = std::error_code();
     intptr_t fd;
@@ -191,8 +193,8 @@ class ProcessFD {
         // This is the spawned process. Any intermediate parent is now dead.
         pid_t my_pid = getpid();
         if (write(pipefds[1], &my_pid, sizeof(my_pid)) == sizeof(my_pid)) {
-          execvpe(argv[0], const_cast(argv),
-                  const_cast(envp));
+          execvpe(
+              argv[0], const_cast(argv), const_cast(envp));
         }
         _exit(errno);  // fork() succeeded and exec() failed, so abort the child
       }
@@ -303,8 +305,12 @@ intptr_t ProcessFD::CloneFD() const {
 #ifdef _WIN32
     HANDLE handle;
     BOOL inheritable = FALSE;
-    fd = DuplicateHandle(GetCurrentProcess(), reinterpret_cast(fd_),
-                         GetCurrentProcess(), &handle, 0, inheritable,
+    fd = DuplicateHandle(GetCurrentProcess(),
+                         reinterpret_cast(fd_),
+                         GetCurrentProcess(),
+                         &handle,
+                         0,
+                         inheritable,
                          DUPLICATE_SAME_ACCESS)
             ? reinterpret_cast(handle)
             : -1;
@@ -339,7 +345,10 @@ Process &Process::operator=(Process other) {

 Process::Process(pid_t pid) { p_ = std::make_shared<ProcessFD>(pid); }

-Process::Process(const char *argv[], void *io_service, std::error_code &ec, bool decouple,
+Process::Process(const char *argv[],
+                 void *io_service,
+                 std::error_code &ec,
+                 bool decouple,
                  const ProcessEnvironment &env) {
   (void)io_service;
   ProcessFD procfd = ProcessFD::spawnvpe(argv, ec, decouple, env);
@@ -555,10 +564,16 @@ pid_t GetParentPID() {
   if (HANDLE parent = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, ppid)) {
     long long me_created, parent_created;
     FILETIME unused;
-    if (GetProcessTimes(GetCurrentProcess(), reinterpret_cast(&me_created),
-                        &unused, &unused, &unused) &&
-        GetProcessTimes(parent, reinterpret_cast(&parent_created), &unused,
-                        &unused, &unused)) {
+    if (GetProcessTimes(GetCurrentProcess(),
+                        reinterpret_cast(&me_created),
+                        &unused,
+                        &unused,
+                        &unused) &&
+        GetProcessTimes(parent,
+                        reinterpret_cast(&parent_created),
+                        &unused,
+                        &unused,
+                        &unused)) {
       if (me_created >= parent_created) {
         // We verified the child is younger than the parent, so we know the parent
         // is still alive.
diff --git a/src/ray/util/process.h b/src/ray/util/process.h
index 82edf3dec..8771326e5 100644
--- a/src/ray/util/process.h
+++ b/src/ray/util/process.h
@@ -73,8 +73,11 @@ class Process {
   /// \param[in] decouple True iff the parent will not wait for the child to exit.
   /// \param[in] env Additional environment variables to be set on this process besides
   /// the environment variables of the parent process.
-  explicit Process(const char *argv[], void *io_service, std::error_code &ec,
-                   bool decouple = false, const ProcessEnvironment &env = {});
+  explicit Process(const char *argv[],
+                   void *io_service,
+                   std::error_code &ec,
+                   bool decouple = false,
+                   const ProcessEnvironment &env = {});
   /// Convenience function to run the given command line and wait for it to finish.
   static std::error_code Call(const std::vector<std::string> &args,
                               const ProcessEnvironment &env = {});
@@ -93,8 +96,10 @@ class Process {
   /// Convenience function to start a process in the background.
   /// \param pid_file A file to write the PID of the spawned process in.
   static std::pair<Process, std::error_code> Spawn(
-      const std::vector<std::string> &args, bool decouple,
-      const std::string &pid_file = std::string(), const ProcessEnvironment &env = {});
+      const std::vector<std::string> &args,
+      bool decouple,
+      const std::string &pid_file = std::string(),
+      const ProcessEnvironment &env = {});
   /// Waits for process to terminate. Not supported for unowned processes.
   /// \return The process's exit code. Returns 0 for a dummy process, -1 for a null one.
   int Wait() const;
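Against the reformatted Process constructor above, spawning a decoupled child looks
roughly like this (argv contents illustrative; the array must stay null-terminated):

    const char *argv[] = {"/bin/sleep", "30", nullptr};
    std::error_code ec;
    Process proc(argv, /*io_service=*/nullptr, ec, /*decouple=*/true);
    if (ec) {
      // spawn failed; ec holds the reason
    }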
diff --git a/src/ray/util/sample.h b/src/ray/util/sample.h
index 6c6861ff0..e963a179e 100644
--- a/src/ray/util/sample.h
+++ b/src/ray/util/sample.h
@@ -21,7 +21,9 @@

 // Randomly samples num_elements from the elements between first and last using reservoir
 // sampling.
 template <class Iterator, class T = typename std::iterator_traits<Iterator>::value_type>
-void random_sample(Iterator begin, Iterator end, size_t num_elements,
+void random_sample(Iterator begin,
+                   Iterator end,
+                   size_t num_elements,
                    std::vector<T> *out) {
   out->resize(0);
   if (num_elements == 0) {
diff --git a/src/ray/util/sample_test.cc b/src/ray/util/sample_test.cc
index a68791114..ed27b4890 100644
--- a/src/ray/util/sample_test.cc
+++ b/src/ray/util/sample_test.cc
@@ -44,8 +44,8 @@ TEST_F(RandomSampleTest, TestEmpty) {
 }

 TEST_F(RandomSampleTest, TestSmallerThanSampleSize) {
-  random_sample(test_vector->begin(), test_vector->end(), test_vector->size() + 1,
-                sample);
+  random_sample(
+      test_vector->begin(), test_vector->end(), test_vector->size() + 1, sample);
   ASSERT_EQ(sample->size(), test_vector->size());
 }

@@ -55,8 +55,8 @@ TEST_F(RandomSampleTest, TestEqualToSampleSize) {
 }

 TEST_F(RandomSampleTest, TestLargerThanSampleSize) {
-  random_sample(test_vector->begin(), test_vector->end(), test_vector->size() - 1,
-                sample);
+  random_sample(
+      test_vector->begin(), test_vector->end(), test_vector->size() - 1, sample);
   ASSERT_EQ(sample->size(), test_vector->size() - 1);
 }

@@ -64,8 +64,8 @@ TEST_F(RandomSampleTest, TestEqualOccurrenceChance) {
   int trials = 1000000;
   std::vector occurrences(test_vector->size(), 0);
   for (int i = 0; i < trials; i++) {
-    random_sample(test_vector->begin(), test_vector->end(), test_vector->size() / 2,
-                  sample);
+    random_sample(
+        test_vector->begin(), test_vector->end(), test_vector->size() / 2, sample);
     for (int idx : *sample) {
       occurrences[idx]++;
     }
diff --git a/src/ray/util/sequencer.h b/src/ray/util/sequencer.h
index 3fc622652..ffa52a1b9 100644
--- a/src/ray/util/sequencer.h
+++ b/src/ray/util/sequencer.h
@@ -72,7 +72,8 @@ class Sequencer {
   absl::Mutex mutex_;

   absl::flat_hash_map<
-      KEY, std::deque>>
+      KEY,
+      std::deque>>
       pending_operations_ GUARDED_BY(mutex_);
 };
diff --git a/src/ray/util/signal_test.cc b/src/ray/util/signal_test.cc
index 9342a8008..c154cdaf0 100644
--- a/src/ray/util/signal_test.cc
+++ b/src/ray/util/signal_test.cc
@@ -98,7 +98,8 @@ TEST(SignalTest, SIGILL_Test) {

 int main(int argc, char **argv) {
   InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
-                                         ray::RayLog::ShutDownRayLog, argv[0],
+                                         ray::RayLog::ShutDownRayLog,
+                                         argv[0],
                                          ray::RayLogLevel::INFO,
                                          /*log_dir=*/"");
   ray::RayLog::InstallFailureSignalHandler(argv[0]);
diff --git a/src/ray/util/throttler.h b/src/ray/util/throttler.h
index a8ef38f33..f0e6ed020 100644
--- a/src/ray/util/throttler.h
+++ b/src/ray/util/throttler.h
@@ -16,6 +16,7 @@

 #include
 #include
+
 #include "absl/time/clock.h"

 namespace ray {
diff --git a/src/ray/util/throttler_test.cc b/src/ray/util/throttler_test.cc
index fc22e97b2..b16b8c5ad 100644
--- a/src/ray/util/throttler_test.cc
+++ b/src/ray/util/throttler_test.cc
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "gtest/gtest.h"
+#include "ray/util/throttler.h"

 #include
 #include

-#include "ray/util/throttler.h"
+#include "gtest/gtest.h"

 TEST(ThrottlerTest, BasicTest) {
   int64_t now = 100;
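Usage of the reservoir sampler from sample.h above, following the reformatted
declaration and the pointer-style output parameter exercised in sample_test.cc
(values illustrative):

    std::vector<int> values = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int> out;
    random_sample(values.begin(), values.end(), /*num_elements=*/3, &out);
    // out now holds 3 elements drawn uniformly without replacement.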
diff --git a/src/ray/util/util.h b/src/ray/util/util.h
index 66feb917b..7cd8650c1 100644
--- a/src/ray/util/util.h
+++ b/src/ray/util/util.h
@@ -304,7 +304,8 @@ class ExponentialBackOff {
   /// \param[in] multiplier The multiplier for this counter.
   /// \param[in] max_value The maximum value for this counter. By default it's
   /// infinite double.
-  ExponentialBackOff(uint64_t initial_value, double multiplier,
+  ExponentialBackOff(uint64_t initial_value,
+                     double multiplier,
                      uint64_t max_value = std::numeric_limits<uint64_t>::max())
       : curr_value_(initial_value),
         initial_value_(initial_value),