ray/rllib/examples/custom_metrics_and_callbacks_legacy.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

95 lines
2.9 KiB
Python

"""Deprecated API; see custom_metrics_and_callbacks.py instead."""
import argparse
import numpy as np
import os
import ray
from ray import tune
def on_episode_start(info):
    """Announce a new episode and reset its pole-angle storage.

    Args:
        info: Callback payload dict. Only ``info["episode"]`` is read; that
            object must expose ``episode_id`` plus the dict-like attributes
            ``user_data`` and ``hist_data``.
    """
    episode = info["episode"]
    print("episode {} started".format(episode.episode_id))
    # Two separate empty lists: ``user_data`` is the scratch buffer that
    # on_episode_step appends to; ``hist_data`` is overwritten again in
    # on_episode_end.
    episode.user_data["pole_angles"] = []
    episode.hist_data["pole_angles"] = []
def on_episode_step(info):
    """Record the absolute pole angle observed at the current step.

    Reads element 2 of both the last observation and the last raw
    observation of the episode (used here as the pole angle), checks that
    the two agree, and appends the value to
    ``episode.user_data["pole_angles"]``.

    Args:
        info: Callback payload dict. Only ``info["episode"]`` is read; it must
            provide ``last_observation_for()``, ``last_raw_obs_for()`` and a
            ``user_data["pole_angles"]`` list (created in on_episode_start).

    Raises:
        AssertionError: If the two observations differ in element 2.
    """
    episode = info["episode"]
    pole_angle = abs(episode.last_observation_for()[2])
    raw_angle = abs(episode.last_raw_obs_for()[2])
    # Sanity check: the example expects the two observation views to carry
    # the same value in this component.
    assert pole_angle == raw_angle, (
        "observation and raw observation disagree on element 2: "
        "{} != {}".format(pole_angle, raw_angle)
    )
    episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(info):
    """Summarize the pole angles collected during a finished episode.

    Stores the mean of ``episode.user_data["pole_angles"]`` under
    ``episode.custom_metrics["pole_angle"]`` and publishes the per-step list
    under ``episode.hist_data["pole_angles"]``.

    Args:
        info: Callback payload dict. Only ``info["episode"]`` is read; it must
            expose ``episode_id``, ``length``, ``user_data``,
            ``custom_metrics`` and ``hist_data``.
    """
    episode = info["episode"]
    pole_angle = np.mean(episode.user_data["pole_angles"])
    print(
        "episode {} ended with length {} and pole angles {}".format(
            episode.episode_id, episode.length, pole_angle
        )
    )
    episode.custom_metrics["pole_angle"] = pole_angle
    # Same list object as in ``user_data`` (no copy is made).
    episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_sample_end(info):
    """Print the size of the sample batch that was just returned.

    Args:
        info: Callback payload dict. Only ``info["samples"].count`` is read.
    """
    print("returned sample batch of size {}".format(info["samples"].count))
def on_train_result(info):
    """Print a training-iteration summary and tag the result dict.

    Args:
        info: Callback payload dict. ``info["trainer"]`` is only formatted
            into the log line; ``info["result"]`` must contain
            ``"episodes_this_iter"`` and is mutated in place.
    """
    print(
        "trainer.train() result: {} -> {} episodes".format(
            info["trainer"], info["result"]["episodes_this_iter"]
        )
    )
    # you can mutate the result dict to add new fields to return
    info["result"]["callback_ok"] = True
def on_postprocess_traj(info):
    """Count how many postprocessed batches an episode has produced.

    Prints the step count of the postprocessed batch and increments
    ``episode.custom_metrics["num_batches"]``, starting from 0 when the key
    is not present yet.

    Args:
        info: Callback payload dict. Reads ``info["episode"]`` (needs a
            dict-like ``custom_metrics``) and ``info["post_batch"]`` (needs a
            ``count`` attribute).
    """
    episode = info["episode"]
    batch = info["post_batch"]
    print("postprocessed {} steps".format(batch.count))
    # Single lookup-with-default instead of a separate membership test.
    metrics = episode.custom_metrics
    metrics["num_batches"] = metrics.get("num_batches", 0) + 1
if __name__ == "__main__":
    # Script entry point: train PG on CartPole-v0 with the dict-based
    # callbacks defined in this file, then check that their effects show up
    # in the first trial's last result.
    parser = argparse.ArgumentParser()
    parser.add_argument("--stop-iters", type=int, default=2000)
    args = parser.parse_args()

    ray.init()
    trials = tune.run(
        "PG",
        stop={
            "training_iteration": args.stop_iters,
        },
        config={
            "env": "CartPole-v0",
            # Callbacks are passed as a dict of plain functions here; the
            # module docstring marks this form as deprecated.
            "callbacks": {
                "on_episode_start": on_episode_start,
                "on_episode_step": on_episode_step,
                "on_episode_end": on_episode_end,
                "on_sample_end": on_sample_end,
                "on_train_result": on_train_result,
                "on_postprocess_traj": on_postprocess_traj,
            },
            "framework": "tf",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        },
    ).trials

    # verify custom metrics for integration tests
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
    # The per-episode "pole_angle" / "num_batches" values are expected to
    # appear in the result with _mean/_min/_max suffixes.
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
    assert "num_batches_mean" in custom_metrics
    # Set by on_train_result.
    assert "callback_ok" in trials[0].last_result