from typing import Dict

from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.types import AgentID, PolicyID


@PublicAPI
class DefaultCallbacks:
    """Abstract base class for RLlib callbacks (similar to Keras callbacks).

    These callbacks can be used for custom metrics and custom postprocessing.

    By default, all of these callbacks are no-ops. To configure custom
    training callbacks, subclass DefaultCallbacks and then set
    {"callbacks": YourCallbacksClass} in the trainer config.
    """

    def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
        if legacy_callbacks_dict:
            deprecation_warning(
                "callbacks dict interface",
                "a class extending rllib.agents.callbacks.DefaultCallbacks")
        self.legacy_callbacks = legacy_callbacks_dict or {}

    def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[PolicyID, Policy],
                         episode: MultiAgentEpisode, **kwargs):
        """Callback run on the rollout worker before each episode starts.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            base_env (BaseEnv): BaseEnv running the episode. The underlying
                env object can be gotten by calling base_env.get_unwrapped().
            policies (dict): Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            episode (MultiAgentEpisode): Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_episode_start"):
            self.legacy_callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })

    def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, **kwargs):
        """Runs on each episode step.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            base_env (BaseEnv): BaseEnv running the episode. The underlying
                env object can be gotten by calling base_env.get_unwrapped().
            episode (MultiAgentEpisode): Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_episode_step"):
            self.legacy_callbacks["on_episode_step"]({
                "env": base_env,
                "episode": episode
            })

    def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: MultiAgentEpisode, **kwargs):
        """Runs when an episode is done.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            base_env (BaseEnv): BaseEnv running the episode. The underlying
                env object can be gotten by calling base_env.get_unwrapped().
            policies (dict): Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            episode (MultiAgentEpisode): Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_episode_end"):
            self.legacy_callbacks["on_episode_end"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })

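    # Illustrative sketch (not part of this class): a subclass could combine
    # the three episode callbacks above by accumulating per-step values in
    # `episode.user_data` and summarizing them into `episode.custom_metrics`
    # at episode end. The metric and key names below are assumptions for the
    # example only; `prev_reward_for()` defaults to the single default agent.
    #
    #     class EpisodeRewardCallbacks(DefaultCallbacks):
    #         def on_episode_start(self, worker, base_env, policies, episode,
    #                              **kwargs):
    #             episode.user_data["step_rewards"] = []
    #
    #         def on_episode_step(self, worker, base_env, episode, **kwargs):
    #             episode.user_data["step_rewards"].append(
    #                 episode.prev_reward_for())
    #
    #         def on_episode_end(self, worker, base_env, policies, episode,
    #                            **kwargs):
    #             rewards = episode.user_data["step_rewards"]
    #             episode.custom_metrics["mean_step_reward"] = (
    #                 sum(rewards) / max(1, len(rewards)))
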
""" if self.legacy_callbacks.get("on_episode_end"): self.legacy_callbacks["on_episode_end"]({ "env": base_env, "policy": policies, "episode": episode, }) def on_postprocess_trajectory( self, worker: RolloutWorker, episode: MultiAgentEpisode, agent_id: AgentID, policy_id: PolicyID, policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch, original_batches: Dict[AgentID, SampleBatch], **kwargs): """Called immediately after a policy's postprocess_fn is called. You can use this callback to do additional postprocessing for a policy, including looking at the trajectory data of other agents in multi-agent settings. Args: worker (RolloutWorker): Reference to the current rollout worker. episode (MultiAgentEpisode): Episode object. agent_id (str): Id of the current agent. policy_id (str): Id of the current policy for the agent. policies (dict): Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy. postprocessed_batch (SampleBatch): The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing. original_batches (dict): Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object. kwargs: Forward compatibility placeholder. """ if self.legacy_callbacks.get("on_postprocess_traj"): self.legacy_callbacks["on_postprocess_traj"]({ "episode": episode, "agent_id": agent_id, "pre_batch": original_batches[agent_id], "post_batch": postprocessed_batch, "all_pre_batches": original_batches, }) def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch, **kwargs): """Called at the end RolloutWorker.sample(). Args: worker (RolloutWorker): Reference to the current rollout worker. samples (SampleBatch): Batch to be returned. You can mutate this object to modify the samples generated. kwargs: Forward compatibility placeholder. """ if self.legacy_callbacks.get("on_sample_end"): self.legacy_callbacks["on_sample_end"]({ "worker": worker, "samples": samples, }) def on_train_result(self, trainer, result: dict, **kwargs): """Called at the end of Trainable.train(). Args: trainer (Trainer): Current trainer instance. result (dict): Dict of results returned from trainer.train() call. You can mutate this object to add additional metrics. kwargs: Forward compatibility placeholder. """ if self.legacy_callbacks.get("on_train_result"): self.legacy_callbacks["on_train_result"]({ "trainer": trainer, "result": result, })