"""
Progress Bar for Ray Actors (tqdm)
==================================

Tracking progress of distributed tasks can be tricky.

This script will demonstrate how to implement a simple
progress bar for a Ray actor to track progress across various
different distributed components.

Original source: `Link `_

Setup: Dependencies
-------------------

First, import some dependencies.
"""

# Inspiration: https://github.com/honnibal/spacy-ray/pull/
# 1/files#diff-7ede881ddc3e8456b320afb958362b2aR12-R45
from asyncio import Event
from typing import Tuple
from time import sleep

import ray

# For typing purposes
from ray.actor import ActorHandle
from tqdm import tqdm

############################################################
# This is the Ray "actor" that can be called from anywhere to update
# our progress. You'll be using the `update` method. Don't
# instantiate this class yourself. Instead,
# it's something that you'll get from a `ProgressBar`.


@ray.remote
class ProgressBarActor:
    """Remote counter that workers bump and a driver polls.

    Keeps a running total (`counter`), the amount added since the last
    poll (`delta`), and an asyncio `Event` used to wake a pending poll.
    """

    counter: int
    delta: int
    event: Event

    def __init__(self) -> None:
        self.event = Event()
        self.delta = 0
        self.counter = 0

    def update(self, num_items_completed: int) -> None:
        """Record that `num_items_completed` more items have finished.

        Adds the amount to both the running total and the not-yet-reported
        delta, then sets the event so a waiting `wait_for_update` resumes.
        """
        self.delta += num_items_completed
        self.counter += num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait for the event set by `update`, then report progress.

        Returns a pair: the number of items added since the previous call
        to this method, and the total number of items completed so far.
        The event is cleared and the delta reset before returning.
        """
        await self.event.wait()
        self.event.clear()
        # Hand back the accumulated delta and zero it in one step.
        pending, self.delta = self.delta, 0
        return pending, self.counter

    def get_counter(self) -> int:
        """Return the total number of items completed so far."""
        return self.counter


######################################################################
# This is where the progress bar starts.
# You create one of these
# on the head node, passing in the expected total number of items,
# and an optional string description.
# Pass along the `actor` reference to any remote task,
# and if they complete ten
# tasks, they'll call `actor.update.remote(10)`.

# Back on the local node, once you launch your remote Ray tasks, call
# `print_until_done`, which will feed everything back into a `tqdm` counter.


class ProgressBar:
    """Driver-side progress bar backed by a remote `ProgressBarActor`.

    Creates the actor on construction, exposes its handle through `actor`,
    and renders the counts it reports with `tqdm` in `print_until_done`.
    """

    progress_actor: ActorHandle
    total: int
    description: str
    # Declared only as an annotation; the bar itself is created locally
    # on each call to `print_until_done`.
    pbar: tqdm

    def __init__(self, total: int, description: str = ""):
        """Create the backing actor and remember the bar's settings.

        Args:
            total: Expected number of items; the bar is done at this count.
            description: Optional label passed to `tqdm` as `desc`.
        """
        # Ray actors don't seem to play nice with mypy, generating
        # a spurious warning for the following line,
        # which we need to suppress. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    @property
    def actor(self) -> ActorHandle:
        """Returns a reference to the remote `ProgressBarActor`.

        When you complete tasks, call `update` on the actor.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Blocking call.

        Do this after starting a series of remote Ray tasks, to which you've
        passed the actor handle. Each of them calls `update` on the actor.
        When the progress meter reaches 100%, this method returns.

        If `total` is zero or negative there is nothing to wait for, so the
        method returns immediately instead of blocking on an update that
        will never arrive.
        """
        # The context manager closes the bar even if `ray.get` raises or the
        # wait is interrupted, so the terminal is never left with a dangling
        # bar.
        with tqdm(desc=self.description, total=self.total) as pbar:
            completed = 0
            while completed < self.total:
                delta, completed = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)


#################################################################
# This is an example of a task that increments the progress bar.
# Note that this is a Ray Task, but it could very well
# be any generic Ray Actor.
#
@ray.remote
def sleep_then_increment(i: int, pba: ActorHandle) -> int:
    """Sleep for `i / 2` seconds, report one finished item, return `i`.

    The update is sent with `.remote(...)` and its result is not awaited.
    """
    sleep(i / 2.0)
    pba.update.remote(1)
    return i


#################################################################
# Now you can run it and see what happens!
#
def run():
    """Launch the demo tasks, show the progress bar, and verify the results.

    Starts Ray, creates a `ProgressBar` for `num_ticks` items, launches one
    `sleep_then_increment` task per item, blocks until the bar completes,
    and then checks both the task results and the actor's final counter.
    """
    ray.init()
    num_ticks = 6
    pb = ProgressBar(num_ticks)
    actor = pb.actor
    # You can replace this with any arbitrary Ray task/actor.
    tasks_pre_launch = [
        sleep_then_increment.remote(i, actor) for i in range(num_ticks)
    ]

    pb.print_until_done()
    tasks = ray.get(tasks_pre_launch)

    # These were bare comparison expressions whose results were discarded;
    # assert them so a mismatch is actually reported.
    assert tasks == list(range(num_ticks)), (
        f"unexpected task results: {tasks!r}"
    )
    final_count = ray.get(actor.get_counter.remote())
    assert num_ticks == final_count, (
        f"expected {num_ticks} completed items, actor reports {final_count}"
    )


run()