"""Example of using RLlib's debug callbacks. Here we use callbacks to track the average CartPole pole angle magnitude as a custom metric. """ from typing import Dict import argparse import numpy as np import os import ray from ray import tune from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.env import BaseEnv from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch parser = argparse.ArgumentParser() parser.add_argument("--torch", action="store_true") parser.add_argument("--stop-iters", type=int, default=2000) class MyCallbacks(DefaultCallbacks): def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode, env_index: int, **kwargs): print("episode {} (env-idx={}) started.".format( episode.episode_id, env_index)) episode.user_data["pole_angles"] = [] episode.hist_data["pole_angles"] = [] def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs): pole_angle = abs(episode.last_observation_for()[2]) raw_angle = abs(episode.last_raw_obs_for()[2]) assert pole_angle == raw_angle episode.user_data["pole_angles"].append(pole_angle) def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode, env_index: int, **kwargs): pole_angle = np.mean(episode.user_data["pole_angles"]) print("episode {} (env-idx={}) ended with length {} and pole " "angles {}".format(episode.episode_id, env_index, episode.length, pole_angle)) episode.custom_metrics["pole_angle"] = pole_angle episode.hist_data["pole_angles"] = episode.user_data["pole_angles"] def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch, **kwargs): print("returned sample batch of size {}".format(samples.count)) def on_train_result(self, *, trainer, result: dict, **kwargs): print("trainer.train() result: {} -> {} episodes".format( trainer, result["episodes_this_iter"])) # you can mutate the result dict to add new fields to return result["callback_ok"] = True def on_postprocess_trajectory( self, *, worker: RolloutWorker, episode: MultiAgentEpisode, agent_id: str, policy_id: str, policies: Dict[str, Policy], postprocessed_batch: SampleBatch, original_batches: Dict[str, SampleBatch], **kwargs): print("postprocessed {} steps".format(postprocessed_batch.count)) if "num_batches" not in episode.custom_metrics: episode.custom_metrics["num_batches"] = 0 episode.custom_metrics["num_batches"] += 1 if __name__ == "__main__": args = parser.parse_args() ray.init() trials = tune.run( "PG", stop={ "training_iteration": args.stop_iters, }, config={ "env": "CartPole-v0", "num_envs_per_worker": 2, "callbacks": MyCallbacks, "framework": "torch" if args.torch else "tf", # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), }).trials # verify custom metrics for integration tests custom_metrics = trials[0].last_result["custom_metrics"] print(custom_metrics) assert "pole_angle_mean" in custom_metrics assert "pole_angle_min" in custom_metrics assert "pole_angle_max" in custom_metrics assert "num_batches_mean" in custom_metrics assert "callback_ok" in trials[0].last_result