From 13c2e13120e2186c21359f657ac92a2aed9cf97b Mon Sep 17 00:00:00 2001 From: konichuvak Date: Wed, 27 May 2020 04:59:28 -0400 Subject: [PATCH] fixing polynomial schedule horizon (#7795) --- rllib/utils/schedules/polynomial_schedule.py | 1 + rllib/utils/schedules/tests/test_schedules.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/rllib/utils/schedules/polynomial_schedule.py b/rllib/utils/schedules/polynomial_schedule.py index d015ff205..f13767358 100644 --- a/rllib/utils/schedules/polynomial_schedule.py +++ b/rllib/utils/schedules/polynomial_schedule.py @@ -35,5 +35,6 @@ class PolynomialSchedule(Schedule): Returns the result of: final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power """ + t = min(t, self.schedule_timesteps) return self.final_p + (self.initial_p - self.final_p) * ( 1.0 - (t / self.schedule_timesteps))**self.power diff --git a/rllib/utils/schedules/tests/test_schedules.py b/rllib/utils/schedules/tests/test_schedules.py index b9b881797..2b3439183 100644 --- a/rllib/utils/schedules/tests/test_schedules.py +++ b/rllib/utils/schedules/tests/test_schedules.py @@ -28,7 +28,7 @@ class TestSchedules(unittest.TestCase): check(out, value) def test_linear_schedule(self): - ts = [0, 50, 10, 100, 90, 2, 1, 99, 23] + ts = [0, 50, 10, 100, 90, 2, 1, 99, 23, 1000] config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6} for fw in framework_iterator( @@ -37,10 +37,10 @@ class TestSchedules(unittest.TestCase): linear = from_config(LinearSchedule, config, framework=fw_) for t in ts: out = linear(t) - check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4) + check(out, 2.1 - (min(t, 100) / 100) * (2.1 - 0.6), decimals=4) def test_polynomial_schedule(self): - ts = [0, 5, 10, 100, 90, 2, 1, 99, 23] + ts = [0, 5, 10, 100, 90, 2, 1, 99, 23, 1000] config = dict( type="ray.rllib.utils.schedules.polynomial_schedule." "PolynomialSchedule", @@ -55,6 +55,7 @@ class TestSchedules(unittest.TestCase): polynomial = from_config(config, framework=fw_) for t in ts: out = polynomial(t) + t = min(t, 100) check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4) def test_exponential_schedule(self):