2020-04-17 02:06:42 +03:00
|
|
|
"""Example of the deprecated dict-of-callback-functions API; see custom_metrics_and_callbacks.py instead."""
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import numpy as np
|
2020-10-02 23:07:44 +02:00
|
|
|
import os
|
2020-04-17 02:06:42 +03:00
|
|
|
|
|
|
|
import ray
|
2022-07-27 04:12:59 -07:00
|
|
|
from ray import air, tune
|
2020-04-17 02:06:42 +03:00
|
|
|
|
|
|
|
|
|
|
|
def on_episode_start(info):
    """Announce a new episode and reset its pole-angle buffers.

    Args:
        info: Callback payload dict; its ``"episode"`` entry must expose
            ``episode_id``, ``user_data`` and ``hist_data``.
    """
    ep = info["episode"]
    print(f"episode {ep.episode_id} started")
    # Fresh per-episode accumulators: user_data collects the per-step
    # angles, hist_data receives the final list when the episode ends.
    ep.user_data["pole_angles"] = []
    ep.hist_data["pole_angles"] = []
|
|
|
|
|
|
|
|
|
|
|
|
def on_episode_step(info):
    """Store the absolute pole angle observed on the current step.

    The angle is read from index 2 of both the last observation and the
    last raw observation. The two values must match (checked with
    ``assert``) before the value is appended to
    ``episode.user_data["pole_angles"]``.

    Args:
        info: Callback payload dict; its ``"episode"`` entry must expose
            ``last_observation_for()``, ``last_raw_obs_for()`` and a
            ``user_data`` dict already holding a ``"pole_angles"`` list.
    """
    ep = info["episode"]
    processed = ep.last_observation_for()
    raw = ep.last_raw_obs_for()
    angle = abs(processed[2])
    # Sanity check: processed and raw observations report the same angle.
    assert angle == abs(raw[2])
    ep.user_data["pole_angles"].append(angle)
|
|
|
|
|
|
|
|
|
|
|
|
def on_episode_end(info):
    """Summarize the episode's pole angles into custom metrics and histograms.

    Prints the episode id, its length and the mean pole angle. Stores that
    mean under ``custom_metrics["pole_angle"]`` and exposes the raw angle
    list (the same list object) under ``hist_data["pole_angles"]``.

    Args:
        info: Callback payload dict; its ``"episode"`` entry must expose
            ``episode_id``, ``length``, ``user_data`` (with a
            ``"pole_angles"`` list), ``custom_metrics`` and ``hist_data``.
    """
    ep = info["episode"]
    angles = ep.user_data["pole_angles"]
    mean_angle = np.mean(angles)
    print(
        f"episode {ep.episode_id} ended with length {ep.length} "
        f"and pole angles {mean_angle}"
    )
    ep.custom_metrics["pole_angle"] = mean_angle
    ep.hist_data["pole_angles"] = angles
|
|
|
|
|
|
|
|
|
|
|
|
def on_sample_end(info):
    """Print the size of the sample batch just collected.

    Args:
        info: Callback payload dict whose ``"samples"`` entry exposes a
            ``count`` attribute.
    """
    batch_size = info["samples"].count
    print(f"returned sample batch of size {batch_size}")
|
|
|
|
|
|
|
|
|
|
|
|
def on_train_result(info):
    """Report the episodes produced by a training iteration and tag the result.

    Args:
        info: Callback payload dict with a ``"trainer"`` entry (only
            printed) and a mutable ``"result"`` dict containing
            ``"episodes_this_iter"``.
    """
    result = info["result"]
    print(
        f"trainer.train() result: {info['trainer']} -> "
        f"{result['episodes_this_iter']} episodes"
    )
    # The result dict may be mutated in place to add new fields to return.
    result["callback_ok"] = True
|
|
|
|
|
|
|
|
|
|
|
|
def on_postprocess_traj(info):
    """Log a postprocessed trajectory and count batches per episode.

    Args:
        info: Callback payload dict with an ``"episode"`` entry exposing a
            ``custom_metrics`` dict and a ``"post_batch"`` entry exposing
            ``count``.
    """
    metrics = info["episode"].custom_metrics
    print(f"postprocessed {info['post_batch'].count} steps")
    # Start the counter at 0 the first time this episode is seen.
    metrics["num_batches"] = metrics.get("num_batches", 0) + 1
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Train PG on CartPole with the legacy dict-of-functions callbacks, then
    # check that the callbacks' custom metrics reached the tuning results.
    parser = argparse.ArgumentParser()
    parser.add_argument("--stop-iters", type=int, default=2000)
    args = parser.parse_args()

    ray.init()

    tuner = tune.Tuner(
        "PG",
        run_config=air.RunConfig(
            stop={
                "training_iteration": args.stop_iters,
            },
        ),
        param_space={
            "env": "CartPole-v0",
            # Deprecated callbacks style: a dict mapping hook name -> function.
            "callbacks": {
                "on_episode_start": on_episode_start,
                "on_episode_step": on_episode_step,
                "on_episode_end": on_episode_end,
                "on_sample_end": on_sample_end,
                "on_train_result": on_train_result,
                "on_postprocess_traj": on_postprocess_traj,
            },
            "framework": "tf",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        },
    )
    results = tuner.fit()

    # Verify custom metrics for integration tests. Fetch the best result once
    # so every check below inspects the same Result object.
    best_metrics = results.get_best_result().metrics
    custom_metrics = best_metrics["custom_metrics"]
    print(custom_metrics)
    # The per-episode "pole_angle" / "num_batches" values set by the
    # callbacks are expected here with _mean/_min/_max suffixes.
    for expected_key in (
        "pole_angle_mean",
        "pole_angle_min",
        "pole_angle_max",
        "num_batches_mean",
    ):
        assert expected_key in custom_metrics, expected_key
    # Field added to the result dict by on_train_result.
    assert "callback_ok" in best_metrics
|