mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
55 lines
2 KiB
Python
55 lines
2 KiB
Python
![]() |
from ray.rllib.utils.schedules.schedule import Schedule
|
||
|
|
||
|
|
||
|
def _linear_interpolation(l, r, alpha):
|
||
|
return l + alpha * (r - l)
|
||
|
|
||
|
|
||
|
class PiecewiseSchedule(Schedule):
    def __init__(self,
                 endpoints,
                 interpolation=_linear_interpolation,
                 outside_value=None,
                 framework=None):
        """Initializes a PiecewiseSchedule.

        Args:
            endpoints (List[Tuple[int,float]]): A list of tuples
                `(t, value)` such that the output
                is an interpolation (given by the `interpolation` callable)
                between two values.
                E.g.
                t=400 and endpoints=[(0, 20.0),(500, 30.0)]
                output=20.0 + 0.8 * 10.0 = 28.0
                NOTE: All the values for time must be sorted in an increasing
                (non-decreasing) order.

            interpolation (callable): A function that takes the left-value,
                the right-value and an alpha interpolation parameter
                (0.0=only left value, 1.0=only right value), which is the
                fraction of distance from left endpoint to right endpoint.

            outside_value (Optional[float]): If `t` in the call to `value` is
                outside of all the intervals in `endpoints`, this value is
                returned. If None, an AssertionError is raised when an
                outside value is requested. Intervals are half-open
                `[t_i, t_{i+1})`, so a `t` equal to the last endpoint's time
                counts as outside.

            framework: Must be None (only the pure-Python path is
                supported here).

        Raises:
            AssertionError: If `framework` is not None, or if the endpoint
                times are not sorted in non-decreasing order.
        """
        # TODO(sven): support tf.
        # Explicit raises instead of `assert` so these checks still run under
        # `python -O`; AssertionError is kept for backward compatibility.
        if framework is not None:
            raise AssertionError(
                "PiecewiseSchedule only supports framework=None, "
                "got {}!".format(framework))
        super().__init__(framework=None)

        idxes = [e[0] for e in endpoints]
        if idxes != sorted(idxes):
            raise AssertionError(
                "PiecewiseSchedule endpoint times must be sorted in "
                "increasing order, got {}!".format(idxes))
        self.interpolation = interpolation
        self.outside_value = outside_value
        self.endpoints = endpoints

    def value(self, t):
        """Returns the schedule's value at time `t`.

        Args:
            t: The time step to evaluate the schedule at.

        Returns:
            `self.interpolation(l, r, alpha)` for the first endpoint pair
            `(l_t, l), (r_t, r)` with `l_t <= t < r_t`, where
            `alpha = (t - l_t) / (r_t - l_t)`; otherwise `self.outside_value`.

        Raises:
            AssertionError: If `t` lies outside all intervals and no
                `outside_value` was given.
        """
        # zip() stops at the shorter sequence, pairing each endpoint with its
        # successor.
        for (l_t, l_value), (r_t, r_value) in zip(self.endpoints,
                                                  self.endpoints[1:]):
            # Half-open interval [l_t, r_t). Zero-length intervals (duplicate
            # times) can never match, so the division below cannot be by 0.
            if l_t <= t < r_t:
                alpha = float(t - l_t) / (r_t - l_t)
                return self.interpolation(l_value, r_value, alpha)

        # t does not belong to any of the pieces.
        if self.outside_value is None:
            raise AssertionError(
                "t={} is outside of all intervals in `endpoints` and no "
                "`outside_value` was provided!".format(t))
        return self.outside_value
|