[RLlib] Discussion 1928: Initial lr wrong if schedule used that includes ts=0 (both tf and torch). (#15538)

Sven Mika 2021-04-27 17:19:52 +02:00 committed by GitHub
parent f5be8d8f74
commit 78b776942f
3 changed files with 36 additions and 20 deletions
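
The bug in one line: both the tf and torch LearningRateSchedule mixins initialized cur_lr from the plain "lr" config value even when an "lr_schedule" was given, so the learning rate was wrong until the first scheduler update. Below is a minimal sketch (not part of this commit) of the value the fix now uses at construction time; the PiecewiseSchedule import path is an assumption, the schedule arguments mirror the diffs below.

# Minimal sketch, not part of this commit; import path assumed.
from ray.rllib.utils.schedules import PiecewiseSchedule

lr = 0.1  # the plain "lr" setting; should be ignored once a schedule is given
lr_schedule = [
    [0, 0.0005],
    [10000, 0.000001],
]
schedule = PiecewiseSchedule(
    lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

print(schedule.value(0))      # 0.0005 -> the value cur_lr should start at
print(schedule.value(20000))  # 1e-06  -> clamped to outside_value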


@@ -4,8 +4,8 @@ import ray
 import ray.rllib.agents.impala as impala
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check_compute_single_action, \
-    framework_iterator
+from ray.rllib.utils.test_utils import check, \
+    check_compute_single_action, framework_iterator

 tf1, tf, tfv = try_import_tf()
@@ -48,22 +48,32 @@ class TestIMPALA(unittest.TestCase):
     def test_impala_lr_schedule(self):
         config = impala.DEFAULT_CONFIG.copy()
+        # Test whether we correctly ignore the "lr" setting.
+        # The first lr should be 0.0005.
         config["lr"] = 0.1
         config["lr_schedule"] = [
             [0, 0.0005],
             [10000, 0.000001],
         ]
-        local_cfg = config.copy()
-        trainer = impala.ImpalaTrainer(config=local_cfg, env="CartPole-v0")
+        config["env"] = "CartPole-v0"

         def get_lr(result):
             return result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]

-        try:
-            r1 = trainer.train()
-            r2 = trainer.train()
-            assert get_lr(r2) < get_lr(r1), (r1, r2)
-        finally:
-            trainer.stop()
+        for fw in framework_iterator(config, frameworks=("tf", "torch")):
+            trainer = impala.ImpalaTrainer(config=config)
+            policy = trainer.get_policy()
+
+            try:
+                if fw == "tf":
+                    check(policy._sess.run(policy.cur_lr), 0.0005)
+                else:
+                    check(policy.cur_lr, 0.0005)
+                r1 = trainer.train()
+                r2 = trainer.train()
+                assert get_lr(r2) < get_lr(r1), (r1, r2)
+            finally:
+                trainer.stop()

 if __name__ == "__main__":


@@ -946,11 +946,15 @@ class LearningRateSchedule:
     @DeveloperAPI
     def __init__(self, lr, lr_schedule):
-        self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
-        self._lr_schedule = lr_schedule
-        if self._lr_schedule is not None:
+        self._lr_schedule = None
+        if lr_schedule is None:
+            self.cur_lr = tf1.get_variable(
+                "lr", initializer=lr, trainable=False)
+        else:
             self._lr_schedule = PiecewiseSchedule(
                 lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
+            self.cur_lr = tf1.get_variable(
+                "lr", initializer=self._lr_schedule.value(0), trainable=False)
             if self.framework == "tf":
                 self._lr_placeholder = tf1.placeholder(
                     dtype=tf.float32, name="lr")
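
For concreteness, a rough standalone sketch (hypothetical, not taken from this commit) of what the new tf branch now does: when a schedule is present, the graph variable is seeded from schedule.value(0) instead of the lr argument. The import paths and session boilerplate are assumptions for illustration.

# Hypothetical sketch of the corrected tf1 initialization path; only the
# get_variable / PiecewiseSchedule calls mirror the diff above.
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.schedules import PiecewiseSchedule

tf1, tf, tfv = try_import_tf()
if tfv == 2:
    tf1.disable_eager_execution()  # sketch uses TF1-style graph/session

lr = 0.1
lr_schedule = [[0, 0.0005], [10000, 0.000001]]

if lr_schedule is None:
    # No schedule: the fixed "lr" setting seeds the variable.
    cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
else:
    schedule = PiecewiseSchedule(
        lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
    # With a schedule, the variable starts at the schedule's value at ts=0.
    cur_lr = tf1.get_variable(
        "lr", initializer=schedule.value(0), trainable=False)

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print(sess.run(cur_lr))  # 0.0005, not the ignored lr=0.1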


@@ -880,20 +880,22 @@ class LearningRateSchedule:
     @DeveloperAPI
     def __init__(self, lr, lr_schedule):
-        self.cur_lr = lr
+        self._lr_schedule = None
         if lr_schedule is None:
-            self.lr_schedule = ConstantSchedule(lr, framework=None)
+            self.cur_lr = lr
         else:
-            self.lr_schedule = PiecewiseSchedule(
+            self._lr_schedule = PiecewiseSchedule(
                 lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
+            self.cur_lr = self._lr_schedule.value(0)

     @override(Policy)
     def on_global_var_update(self, global_vars):
         super().on_global_var_update(global_vars)
-        self.cur_lr = self.lr_schedule.value(global_vars["timestep"])
-        for opt in self._optimizers:
-            for p in opt.param_groups:
-                p["lr"] = self.cur_lr
+        if self._lr_schedule:
+            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
+            for opt in self._optimizers:
+                for p in opt.param_groups:
+                    p["lr"] = self.cur_lr

     @DeveloperAPI
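
Similarly, a hypothetical torch-side sketch (illustrative names, not RLlib's classes) of the behavior after this change: cur_lr starts at the schedule's value at timestep 0, and optimizer param groups are only rewritten when a schedule is actually configured.

# Hypothetical sketch mirroring the guarded torch update above.
import torch
from ray.rllib.utils.schedules import PiecewiseSchedule

lr = 0.1
lr_schedule = [[0, 0.0005], [10000, 0.000001]]

# Construction-time behavior after the fix:
if lr_schedule is None:
    schedule, cur_lr = None, lr
else:
    schedule = PiecewiseSchedule(
        lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
    cur_lr = schedule.value(0)  # 0.0005 immediately, before any training

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=cur_lr)

# on_global_var_update-style behavior: only touch param groups if scheduled.
def update_lr(timestep):
    if schedule:
        new_lr = schedule.value(timestep)
        for group in optimizer.param_groups:
            group["lr"] = new_lr

print(optimizer.param_groups[0]["lr"])  # 0.0005
update_lr(10000)
print(optimizer.param_groups[0]["lr"])  # 1e-06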