"""Example of using RLlib's debug callbacks.
|
|
|
|
|
|
|
|
Here we use callbacks to track the average CartPole pole angle magnitude as a
|
|
|
|
custom metric.
|
|
|
|
"""
|
|
|
|
|
2020-04-17 02:06:42 +03:00
|
|
|
from typing import Dict
import argparse
import numpy as np
import os

import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

parser = argparse.ArgumentParser()
parser.add_argument("--torch", action="store_true")
parser.add_argument("--stop-iters", type=int, default=2000)

class MyCallbacks(DefaultCallbacks):
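    """Custom callbacks, hooked into training via the `callbacks` key of the
    `tune.run()` config below; RLlib invokes each overridden method at the
    corresponding point in the sampling/training loop."""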

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, env_index: int, **kwargs):
        # Make sure this episode has just been started (only initial obs
        # logged so far).
        assert episode.length == 0, \
            "ERROR: `on_episode_start()` callback should be called right " \
            "after env reset!"
        print("episode {} (env-idx={}) started.".format(
            episode.episode_id, env_index))
        episode.user_data["pole_angles"] = []
        episode.hist_data["pole_angles"] = []

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int, **kwargs):
        # Make sure this episode is ongoing.
        assert episode.length > 0, \
            "ERROR: `on_episode_step()` callback should not be called right " \
            "after env reset!"
        pole_angle = abs(episode.last_observation_for()[2])
        raw_angle = abs(episode.last_raw_obs_for()[2])
        assert pole_angle == raw_angle
        episode.user_data["pole_angles"].append(pole_angle)

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       env_index: int, **kwargs):
        # Make sure this episode is really done.
        assert episode.batch_builder.policy_collectors[
            "default_policy"].buffers["dones"][-1], \
            "ERROR: `on_episode_end()` should only be called " \
            "after episode is done!"
        pole_angle = np.mean(episode.user_data["pole_angles"])
        print("episode {} (env-idx={}) ended with length {} and mean pole "
              "angle {}".format(episode.episode_id, env_index, episode.length,
                                pole_angle))
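        # Scalars stored in `custom_metrics` are aggregated across episodes;
        # lists stored in `hist_data` are collected under `hist_stats` in
        # the training results.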
        episode.custom_metrics["pole_angle"] = pole_angle
        episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]

    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
                      **kwargs):
        print("returned sample batch of size {}".format(samples.count))

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print("trainer.train() result: {} -> {} episodes".format(
            trainer, result["episodes_this_iter"]))
        # You can mutate the result dict to add new fields to return.
        result["callback_ok"] = True

    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                          result: dict, **kwargs) -> None:
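        # Metrics written into `result` here end up per-policy under
        # `custom_metrics` (verified for the torch case in the `__main__`
        # block below).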
        result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
        print("policy.learn_on_batch() result: {} -> sum actions: {}".format(
            policy, result["sum_actions_in_train_batch"]))

    def on_postprocess_trajectory(
            self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
            agent_id: str, policy_id: str, policies: Dict[str, Policy],
            postprocessed_batch: SampleBatch,
            original_batches: Dict[str, SampleBatch], **kwargs):
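        # Count the postprocessed batches per episode; reported below as the
        # `num_batches` custom metric.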
        print("postprocessed {} steps".format(postprocessed_batch.count))
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1


if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()
    trials = tune.run(
        "PG",
        stop={
            "training_iteration": args.stop_iters,
        },
        config={
            "env": "CartPole-v0",
            "num_envs_per_worker": 2,
            "callbacks": MyCallbacks,
            "framework": "torch" if args.torch else "tf",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        }).trials

    # Verify episode-related custom metrics are there.
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
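    # RLlib aggregates each scalar custom metric across episodes into
    # `_mean`, `_min` and `_max` entries.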
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
    assert "num_batches_mean" in custom_metrics
    assert "callback_ok" in trials[0].last_result

    # Verify `on_learn_on_batch` custom metrics are there (per policy).
    if args.torch:
        info_custom_metrics = custom_metrics["default_policy"]
        print(info_custom_metrics)
        assert "sum_actions_in_train_batch" in info_custom_metrics