ray/rllib/utils/schedules/exponential_schedule.py

42 lines
1.6 KiB
Python

from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules.schedule import Schedule
torch, _ = try_import_torch()
class ExponentialSchedule(Schedule):
    """Schedule whose value decays exponentially with the time step.

    For a time step `t` the value is
    `initial_p * decay_rate ** (t / schedule_timesteps)`.
    """

    def __init__(self,
                 schedule_timesteps,
                 framework,
                 initial_p=1.0,
                 decay_rate=0.1):
        """Sets up the exponential decay parameters.

        Args:
            schedule_timesteps (int): Number of time steps after which the
                value has decayed to `initial_p * decay_rate`. Must be > 0.
            framework (Optional[str]): Framework identifier, passed on to
                the `Schedule` base class. If it is "torch", torch tensors
                given as time step are cast to float before the computation.
            initial_p (float): Value at time step 0.
            decay_rate (float): Fraction of `initial_p` that is left once
                `schedule_timesteps` time steps have passed.
                Between 0.0 and 1.0: the smaller, the stronger the decay.
                1.0: No decay at all.

        Raises:
            AssertionError: If `schedule_timesteps` is not greater than 0.
        """
        super().__init__(framework=framework)
        assert schedule_timesteps > 0
        self.decay_rate = decay_rate
        self.initial_p = initial_p
        self.schedule_timesteps = schedule_timesteps

    @override(Schedule)
    def _value(self, t):
        """Computes initial_p * decay_rate ** (`t` / schedule_timesteps).

        No clamping of `t` is done in this method.
        """
        # Torch tensors are cast to float first so that the division below
        # operates on a floating point tensor.
        cast_needed = (self.framework == "torch" and torch
                       and isinstance(t, torch.Tensor))
        if cast_needed:
            t = t.float()
        # Share of the schedule horizon that has elapsed at time step `t`.
        elapsed_fraction = t / self.schedule_timesteps
        decay_factor = self.decay_rate ** elapsed_fraction
        return self.initial_p * decay_factor