Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Add a unittest for learning rate schedule used with APEX agent. (#18389)
This commit is contained in:
parent c91e0eb065
commit 808b683f81
2 changed files with 61 additions and 0 deletions
@@ -79,6 +79,8 @@ DEFAULT_CONFIG = with_common_config({
    # Learning rate for adam optimizer
    "lr": 5e-4,
    # Learning rate schedule
    # In the format of [[timestep, value], [timestep, value], ...]
    # A schedule should normally start from timestep 0.
    "lr_schedule": None,
    # Adam epsilon hyper parameter
    "adam_epsilon": 1e-8,
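As a usage note, the "lr_schedule" format documented above ([[timestep, value], ...]) can be plugged straight into the APEX config. A minimal, illustrative sketch; the schedule endpoints and the CartPole-v0 environment are assumptions for demonstration, not part of this commit:

import ray
import ray.rllib.agents.dqn.apex as apex

ray.init()

# Illustrative only: anneal the learning rate from 1e-3 to 1e-5 over the
# first 100000 sampled timesteps using the [[timestep, value], ...] format.
config = apex.APEX_DEFAULT_CONFIG.copy()
config["lr"] = 1e-3  # starting value; the schedule takes over from timestep 0
config["lr_schedule"] = [
    [0, 1e-3],
    [100000, 1e-5],
]
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
trainer.train()
trainer.stop()

After each train() call, the currently active value is reported under results["info"]["learner"][<policy_id>]["cur_lr"], which is exactly what the new test below reads back.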
@@ -3,6 +3,7 @@ import unittest

import ray
import ray.rllib.agents.dqn.apex as apex
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
    framework_iterator
@@ -62,6 +63,64 @@ class TestApexDQN(unittest.TestCase):
        trainer.stop()

    def test_apex_lr_schedule(self):
        config = apex.APEX_DEFAULT_CONFIG.copy()
        config["num_workers"] = 1
        config["num_gpus"] = 0
        config["buffer_size"] = 100
        config["learning_starts"] = 10
        config["train_batch_size"] = 10
        config["rollout_fragment_length"] = 5
        config["prioritized_replay"] = True
        config["timesteps_per_iteration"] = 10
        # 0 metrics reporting delay. This makes sure the timestep,
        # which the lr depends on, is updated after each worker rollout.
        config["min_iter_time_s"] = 0
        config["optimizer"]["num_replay_buffer_shards"] = 1
        # This makes sure the learning schedule is checked every 10 timesteps.
        config["optimizer"]["max_weight_sync_delay"] = 10
        # Initial lr; doesn't really matter because of the schedule below.
        config["lr"] = 0.2
        lr_schedule = [
            [0, 0.2],
            [50, 0.1],
            [100, 0.01],
            [150, 0.001],
        ]
        config["lr_schedule"] = lr_schedule

        def _step_n_times(trainer, n: int):
            """Step trainer n times.

            Returns:
                Learning rate at the end of the execution.
            """
            for _ in range(n):
                results = trainer.train()
            return results["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]

        # Check eager execution frameworks here, since it's easier to control
        # exact timesteps with these frameworks.
        for _ in framework_iterator(config):
            trainer = apex.ApexTrainer(config=config, env="CartPole-v0")

            lr = _step_n_times(trainer, 5)  # 50 timesteps
            # PiecewiseSchedule does interpolation, so roughly 0.1 here.
            self.assertLessEqual(lr, 0.15)
            self.assertGreaterEqual(lr, 0.05)

            lr = _step_n_times(trainer, 5)  # 100 timesteps
            # PiecewiseSchedule does interpolation, so roughly 0.01 here.
            self.assertLessEqual(lr, 0.02)
            self.assertGreaterEqual(lr, 0.005)

            lr = _step_n_times(trainer, 5)  # 150 timesteps
            # PiecewiseSchedule does interpolation, so roughly 0.001 here.
            self.assertLessEqual(lr, 0.002)
            self.assertGreaterEqual(lr, 0.0005)

            trainer.stop()


if __name__ == "__main__":
    import sys
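For context on the "roughly 0.1 / 0.01 / 0.001" assertions above: with timesteps_per_iteration = 10, each _step_n_times(trainer, 5) call advances about 50 timesteps, landing near the next schedule endpoint, and PiecewiseSchedule interpolates linearly between the [[timestep, value], ...] endpoints. The following standalone sketch only mimics that interpolation to show where the expected values come from; the piecewise_lr helper is hypothetical and is not RLlib's implementation:

def piecewise_lr(t, schedule):
    """Linearly interpolate `schedule` ([[timestep, value], ...]) at timestep `t`."""
    for (t0, v0), (t1, v1) in zip(schedule, schedule[1:]):
        if t0 <= t < t1:
            frac = (t - t0) / (t1 - t0)
            return v0 + frac * (v1 - v0)
    return schedule[-1][1]  # past the last endpoint: hold the final value

schedule = [[0, 0.2], [50, 0.1], [100, 0.01], [150, 0.001]]
print(piecewise_lr(50, schedule))   # 0.1
print(piecewise_lr(75, schedule))   # 0.055 (halfway between 0.1 and 0.01)
print(piecewise_lr(100, schedule))  # 0.01
print(piecewise_lr(150, schedule))  # 0.001

The test asserts loose upper and lower bounds rather than exact equality, presumably because the number of sampled timesteps can overshoot each endpoint slightly in APEX's asynchronous setup, shifting the interpolated value.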