[RLlib] Add a unittest for learning rate schedule used with APEX agent. (#18389)

This commit is contained in:
gjoliver 2021-09-08 14:29:40 -07:00 committed by GitHub
parent c91e0eb065
commit 808b683f81
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 0 deletions

View file

@ -79,6 +79,8 @@ DEFAULT_CONFIG = with_common_config({
# Learning rate for adam optimizer
"lr": 5e-4,
# Learning rate schedule
# In the format of [[timestep, value], [timestep, value], ...]
# A schedule should normally start from timestep 0.
"lr_schedule": None,
# Adam epsilon hyperparameter
"adam_epsilon": 1e-8,

View file

@ -3,6 +3,7 @@ import unittest
import ray
import ray.rllib.agents.dqn.apex as apex
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator
@ -62,6 +63,64 @@ class TestApexDQN(unittest.TestCase):
trainer.stop()
def test_apex_lr_schedule(self):
    """Verify that ``lr_schedule`` drives ``cur_lr`` over training timesteps."""
    config = apex.APEX_DEFAULT_CONFIG.copy()
    config.update({
        "num_workers": 1,
        "num_gpus": 0,
        "buffer_size": 100,
        "learning_starts": 10,
        "train_batch_size": 10,
        "rollout_fragment_length": 5,
        "prioritized_replay": True,
        "timesteps_per_iteration": 10,
        # 0 metrics reporting delay, this makes sure timestep,
        # which lr depends on, is updated after each worker rollout.
        "min_iter_time_s": 0,
        # Initial lr, doesn't really matter because of the schedule below.
        "lr": 0.2,
        # In the format of [[timestep, value], ...];
        # a schedule should normally start from timestep 0.
        "lr_schedule": [[0, 0.2], [50, 0.1], [100, 0.01], [150, 0.001]],
    })
    config["optimizer"]["num_replay_buffer_shards"] = 1
    # This makes sure learning schedule is checked every 10 timesteps.
    config["optimizer"]["max_weight_sync_delay"] = 10

    def _lr_after_iters(trainer, num_iters: int):
        """Run `num_iters` train() calls; return the resulting cur_lr."""
        result = None
        for _ in range(num_iters):
            result = trainer.train()
        return result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]

    # Check eager execution frameworks here, since it's easier to control
    # exact timesteps with these frameworks.
    for _ in framework_iterator(config):
        trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
        # Each pass below runs 5 iterations (50 timesteps). PiecewiseSchedule
        # interpolates between the schedule points, so assert a band around
        # each scheduled target (~0.1, ~0.01, ~0.001) rather than an exact
        # value.
        for upper_bound, lower_bound in (
            (0.15, 0.05),
            (0.02, 0.005),
            (0.002, 0.0005),
        ):
            lr = _lr_after_iters(trainer, 5)
            self.assertLessEqual(lr, upper_bound)
            self.assertGreaterEqual(lr, lower_bound)
        trainer.stop()
if __name__ == "__main__":
import sys