"""Example of using RLlib's debug callbacks. Here we use callbacks to track the average CartPole pole angle magnitude as a custom metric. """ from typing import Dict import argparse import numpy as np import ray from ray import tune from ray.rllib.env import BaseEnv from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker from ray.rllib.agents.callbacks import DefaultCallbacks class MyCallbacks(DefaultCallbacks): def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode, **kwargs): print("episode {} started".format(episode.episode_id)) episode.user_data["pole_angles"] = [] episode.hist_data["pole_angles"] = [] def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, **kwargs): pole_angle = abs(episode.last_observation_for()[2]) raw_angle = abs(episode.last_raw_obs_for()[2]) assert pole_angle == raw_angle episode.user_data["pole_angles"].append(pole_angle) def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode, **kwargs): pole_angle = np.mean(episode.user_data["pole_angles"]) print("episode {} ended with length {} and pole angles {}".format( episode.episode_id, episode.length, pole_angle)) episode.custom_metrics["pole_angle"] = pole_angle episode.hist_data["pole_angles"] = episode.user_data["pole_angles"] def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch, **kwargs): print("returned sample batch of size {}".format(samples.count)) def on_train_result(self, trainer, result: dict, **kwargs): print("trainer.train() result: {} -> {} episodes".format( trainer, result["episodes_this_iter"])) # you can mutate the result dict to add new fields to return result["callback_ok"] = True def on_postprocess_trajectory( self, worker: RolloutWorker, episode: MultiAgentEpisode, agent_id: str, policy_id: str, policies: Dict[str, Policy], postprocessed_batch: SampleBatch, original_batches: Dict[str, SampleBatch], **kwargs): print("postprocessed {} steps".format(postprocessed_batch.count)) if "num_batches" not in episode.custom_metrics: episode.custom_metrics["num_batches"] = 0 episode.custom_metrics["num_batches"] += 1 if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num-iters", type=int, default=2000) args = parser.parse_args() ray.init() trials = tune.run( "PG", stop={ "training_iteration": args.num_iters, }, config={ "env": "CartPole-v0", "callbacks": MyCallbacks, }, return_trials=True) # verify custom metrics for integration tests custom_metrics = trials[0].last_result["custom_metrics"] print(custom_metrics) assert "pole_angle_mean" in custom_metrics assert "pole_angle_min" in custom_metrics assert "pole_angle_max" in custom_metrics assert "num_batches_mean" in custom_metrics assert "callback_ok" in trials[0].last_result