ray/rllib/utils/schedules.py

108 lines
3.8 KiB
Python

"""This file is used for specifying various schedules that evolve over
time throughout the execution of the algorithm, such as:
- learning rate for the optimizer
- exploration epsilon for the epsilon greedy exploration strategy
- beta parameter in prioritized replay
Each schedule has a function `value(t)` which returns the current value
of the parameter given the timestep t of the optimization procedure.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Schedule(object):
    """Interface for a quantity that changes with the timestep."""

    def value(self, t):
        """Return the value of the schedule at timestep ``t``.

        This base version has no value to report and always raises
        ``NotImplementedError``; concrete schedules supply their own
        ``value``.
        """
        raise NotImplementedError
class ConstantSchedule(object):
    """Schedule whose output is the same at every timestep."""

    def __init__(self, value):
        """Remember the fixed output of the schedule.

        Parameters
        ----------
        value: float
            The value handed back for every timestep.
        """
        self._v = value

    def value(self, t):
        """Return the stored constant; ``t`` does not influence it."""
        return self._v
def linear_interpolation(l, r, alpha):
    """Return the point a fraction ``alpha`` of the way from ``l`` to ``r``."""
    span = r - l
    return l + alpha * span
class PiecewiseSchedule(object):
    """Schedule that interpolates between consecutive ``(time, value)`` pairs."""

    def __init__(self,
                 endpoints,
                 interpolation=linear_interpolation,
                 outside_value=None):
        """Piecewise schedule.

        Parameters
        ----------
        endpoints: [(int, float)]
            Pairs `(time, value)` meaning that the schedule should output
            `value` when `t == time`. The times must be sorted in
            increasing order. When `t` lies between two times, e.g.
            `(time_a, value_a)` and `(time_b, value_b)` such that
            `time_a <= t < time_b`, the schedule outputs
            `interpolation(value_a, value_b, alpha)` where `alpha` is the
            fraction of the time between `time_a` and `time_b` that has
            passed at `t`.
        interpolation: lambda float, float, float: float
            Function taking the value to the left of `t`, the value to the
            right of `t`, and `alpha` (the fraction of the distance from
            the left endpoint to the right endpoint that `t` has covered).
            See `linear_interpolation` for an example.
        outside_value: float
            Returned when `t` falls outside every interval specified by
            `endpoints`. If None, an AssertionError is raised instead.

        Raises
        ------
        AssertionError
            If the times in `endpoints` are not sorted.
        """
        # Copy into a list so any iterable is accepted and so that later
        # mutation of the caller's sequence cannot invalidate the
        # sortedness check performed below.
        endpoints = list(endpoints)
        times = [point[0] for point in endpoints]
        # Raised explicitly instead of via `assert` so that the check is
        # still performed when Python runs with -O (which strips asserts).
        if times != sorted(times):
            raise AssertionError(
                "endpoint times must be sorted in increasing order, "
                "got %r" % (times,))
        self._interpolation = interpolation
        self._outside_value = outside_value
        self._endpoints = endpoints

    def value(self, t):
        """See Schedule.value

        Intervals are half-open (`time_a <= t < time_b`), so a `t` equal
        to the time of the last endpoint counts as outside of all
        intervals.

        Raises
        ------
        AssertionError
            If `t` is outside of all intervals and no `outside_value` was
            given.
        """
        consecutive = zip(self._endpoints[:-1], self._endpoints[1:])
        for (left_t, left_v), (right_t, right_v) in consecutive:
            if left_t <= t < right_t:
                alpha = float(t - left_t) / (right_t - left_t)
                return self._interpolation(left_v, right_v, alpha)

        # t does not belong to any of the pieces. Raise explicitly rather
        # than `assert` so that running under -O cannot turn this into a
        # silent `return None`.
        if self._outside_value is None:
            raise AssertionError(
                "t=%r is outside of all schedule intervals and no "
                "outside_value was given" % (t,))
        return self._outside_value
class LinearSchedule(object):
    """Schedule that moves linearly from ``initial_p`` to ``final_p``."""

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Anneal linearly from initial_p to final_p.

        Once schedule_timesteps timesteps have passed, final_p is
        returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps over which initial_p is linearly annealed
            to final_p
        initial_p: float
            output value at the start of the schedule
        final_p: float
            output value at the end of the schedule
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value"""
        # A horizon of at least 1 avoids dividing by zero when
        # schedule_timesteps is 0.
        horizon = max(1, self.schedule_timesteps)
        progress = float(t) / horizon
        # Cap the progress so the output stays at final_p past the horizon.
        if progress > 1.0:
            progress = 1.0
        delta = self.final_p - self.initial_p
        return self.initial_p + progress * delta