[RLlib] Discussion 1928: Initial lr wrong if schedule used that includes ts=0 (both tf and torch). (#15538)
parent f5be8d8f74
commit 78b776942f

3 changed files with 36 additions and 20 deletions
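The gist of the fix: when an "lr_schedule" is given, cur_lr must be seeded from the schedule's value at timestep 0 rather than from the separate "lr" config setting; previously the "lr" value leaked through until the first global-var update. Below is a minimal sketch of the pre-fix symptom. It is not part of the commit and assumes PiecewiseSchedule is importable from ray.rllib.utils.schedules and behaves as the diff uses it:

# Sketch only: illustrates the symptom this commit fixes (assumed import path).
from ray.rllib.utils.schedules import PiecewiseSchedule

lr = 0.1  # The "lr" setting that should be ignored when a schedule exists.
lr_schedule = [[0, 0.0005], [10000, 0.000001]]

schedule = PiecewiseSchedule(
    lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

# The intended initial learning rate is the schedule's t=0 value ...
print(schedule.value(0))  # -> 0.0005
# ... but before this commit, cur_lr was initialized to `lr` (0.1) and only
# corrected once the first on_global_var_update() call came in.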
rllib/agents/impala/tests/test_impala.py

@@ -4,8 +4,8 @@ import ray
 import ray.rllib.agents.impala as impala
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check_compute_single_action, \
-    framework_iterator
+from ray.rllib.utils.test_utils import check, \
+    check_compute_single_action, framework_iterator

 tf1, tf, tfv = try_import_tf()

@@ -48,22 +48,32 @@ class TestIMPALA(unittest.TestCase):

     def test_impala_lr_schedule(self):
         config = impala.DEFAULT_CONFIG.copy()
+        # Test whether we correctly ignore the "lr" setting.
+        # The first lr should be 0.0005.
+        config["lr"] = 0.1
         config["lr_schedule"] = [
             [0, 0.0005],
             [10000, 0.000001],
         ]
-        local_cfg = config.copy()
-        trainer = impala.ImpalaTrainer(config=local_cfg, env="CartPole-v0")
+        config["env"] = "CartPole-v0"

         def get_lr(result):
             return result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]

-        try:
-            r1 = trainer.train()
-            r2 = trainer.train()
-            assert get_lr(r2) < get_lr(r1), (r1, r2)
-        finally:
-            trainer.stop()
+        for fw in framework_iterator(config, frameworks=("tf", "torch")):
+            trainer = impala.ImpalaTrainer(config=config)
+            policy = trainer.get_policy()
+
+            try:
+                if fw == "tf":
+                    check(policy._sess.run(policy.cur_lr), 0.0005)
+                else:
+                    check(policy.cur_lr, 0.0005)
+                r1 = trainer.train()
+                r2 = trainer.train()
+                assert get_lr(r2) < get_lr(r1), (r1, r2)
+            finally:
+                trainer.stop()


 if __name__ == "__main__":
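The updated test now loops over both frameworks and asserts the initial value directly, before any training step. To run just this test (the path and pytest selection syntax are assumptions based on RLlib's usual layout, not stated in the diff):

pytest rllib/agents/impala/tests/test_impala.py::TestIMPALA::test_impala_lr_schedule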
rllib/policy/tf_policy.py

@@ -946,11 +946,15 @@ class LearningRateSchedule:

     @DeveloperAPI
     def __init__(self, lr, lr_schedule):
-        self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
-        self._lr_schedule = lr_schedule
-        if self._lr_schedule is not None:
+        self._lr_schedule = None
+        if lr_schedule is None:
+            self.cur_lr = tf1.get_variable(
+                "lr", initializer=lr, trainable=False)
+        else:
             self._lr_schedule = PiecewiseSchedule(
                 lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
+            self.cur_lr = tf1.get_variable(
+                "lr", initializer=self._lr_schedule.value(0), trainable=False)
             if self.framework == "tf":
                 self._lr_placeholder = tf1.placeholder(
                     dtype=tf.float32, name="lr")
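A standalone sketch of the graph-mode pattern the tf fix relies on: create the lr variable with the schedule's t=0 value as its initializer, and pair it with a placeholder for later updates. This is illustrative only; the assign op below is an assumption about how the mixin consumes _lr_placeholder, since that part is outside the hunk:

# Sketch only (TF1 graph mode via tf.compat.v1); not the actual RLlib code.
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

initial_lr = 0.0005  # stands in for self._lr_schedule.value(0)
cur_lr = tf1.get_variable("lr", initializer=initial_lr, trainable=False)
lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
# Assumed counterpart to _lr_placeholder: an op that writes new values in.
lr_update = cur_lr.assign(lr_placeholder, read_value=False)

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print(sess.run(cur_lr))  # 0.0005 from step one, not a stale "lr"
    sess.run(lr_update, feed_dict={lr_placeholder: 0.0001})
    print(sess.run(cur_lr))  # 0.0001 after a scheduled update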
rllib/policy/torch_policy.py

@@ -880,20 +880,22 @@ class LearningRateSchedule:

     @DeveloperAPI
     def __init__(self, lr, lr_schedule):
-        self.cur_lr = lr
+        self._lr_schedule = None
         if lr_schedule is None:
-            self.lr_schedule = ConstantSchedule(lr, framework=None)
+            self.cur_lr = lr
         else:
-            self.lr_schedule = PiecewiseSchedule(
+            self._lr_schedule = PiecewiseSchedule(
                 lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
+            self.cur_lr = self._lr_schedule.value(0)

     @override(Policy)
     def on_global_var_update(self, global_vars):
         super().on_global_var_update(global_vars)
-        self.cur_lr = self.lr_schedule.value(global_vars["timestep"])
-        for opt in self._optimizers:
-            for p in opt.param_groups:
-                p["lr"] = self.cur_lr
+        if self._lr_schedule:
+            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
+            for opt in self._optimizers:
+                for p in opt.param_groups:
+                    p["lr"] = self.cur_lr


 @DeveloperAPI
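On the torch side the fix is the same idea without graph variables: seed cur_lr from value(0) and only walk the optimizers' param groups when a schedule exists. A self-contained sketch with a toy schedule object (the real code uses PiecewiseSchedule; everything below is illustrative):

# Sketch only: the param-group update pattern, with a toy schedule object.
import torch

class ToySchedule:
    """Linear decay from 0.0005 at t=0 to 0.000001 at t=10000."""
    def value(self, t):
        frac = min(t, 10000) / 10000.0
        return 0.0005 + frac * (0.000001 - 0.0005)

model = torch.nn.Linear(4, 2)
# 0.1 plays the role of the "lr" setting that the schedule should override.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

schedule = ToySchedule()
cur_lr = schedule.value(0)  # seed from t=0, as the fixed __init__ does
for p in optimizer.param_groups:
    p["lr"] = cur_lr
print(optimizer.param_groups[0]["lr"])  # 0.0005, not 0.1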