mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
fixing polynomial schedule horizon (#7795)
This commit is contained in:
parent
067bbb6710
commit
13c2e13120
2 changed files with 5 additions and 3 deletions
|
@@ -35,5 +35,6 @@ class PolynomialSchedule(Schedule):
|
||||||
Returns the result of:
|
Returns the result of:
|
||||||
final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
|
final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
|
||||||
"""
|
"""
|
||||||
|
t = min(t, self.schedule_timesteps)
|
||||||
return self.final_p + (self.initial_p - self.final_p) * (
|
return self.final_p + (self.initial_p - self.final_p) * (
|
||||||
1.0 - (t / self.schedule_timesteps))**self.power
|
1.0 - (t / self.schedule_timesteps))**self.power
|
||||||
|
|
|
@@ -28,7 +28,7 @@ class TestSchedules(unittest.TestCase):
|
||||||
check(out, value)
|
check(out, value)
|
||||||
|
|
||||||
def test_linear_schedule(self):
|
def test_linear_schedule(self):
|
||||||
ts = [0, 50, 10, 100, 90, 2, 1, 99, 23]
|
ts = [0, 50, 10, 100, 90, 2, 1, 99, 23, 1000]
|
||||||
config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6}
|
config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6}
|
||||||
|
|
||||||
for fw in framework_iterator(
|
for fw in framework_iterator(
|
||||||
|
@@ -37,10 +37,10 @@ class TestSchedules(unittest.TestCase):
|
||||||
linear = from_config(LinearSchedule, config, framework=fw_)
|
linear = from_config(LinearSchedule, config, framework=fw_)
|
||||||
for t in ts:
|
for t in ts:
|
||||||
out = linear(t)
|
out = linear(t)
|
||||||
check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
|
check(out, 2.1 - (min(t, 100) / 100) * (2.1 - 0.6), decimals=4)
|
||||||
|
|
||||||
def test_polynomial_schedule(self):
|
def test_polynomial_schedule(self):
|
||||||
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
|
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23, 1000]
|
||||||
config = dict(
|
config = dict(
|
||||||
type="ray.rllib.utils.schedules.polynomial_schedule."
|
type="ray.rllib.utils.schedules.polynomial_schedule."
|
||||||
"PolynomialSchedule",
|
"PolynomialSchedule",
|
||||||
|
@@ -55,6 +55,7 @@ class TestSchedules(unittest.TestCase):
|
||||||
polynomial = from_config(config, framework=fw_)
|
polynomial = from_config(config, framework=fw_)
|
||||||
for t in ts:
|
for t in ts:
|
||||||
out = polynomial(t)
|
out = polynomial(t)
|
||||||
|
t = min(t, 100)
|
||||||
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
|
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
|
||||||
|
|
||||||
def test_exponential_schedule(self):
|
def test_exponential_schedule(self):
|
||||||
|
|
Loading…
Add table
Reference in a new issue