ray/rllib/utils/schedules/schedule.py

54 lines
1.5 KiB
Python
Raw Normal View History

from abc import ABCMeta, abstractmethod
from ray.rllib.utils.framework import check_framework
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class Schedule(metaclass=ABCMeta):
    """Abstract base class for time-dependent value schedules.

    Schedule classes implement various time-dependent scheduling schemas,
    such as:
    - Constant behavior.
    - Linear decay.
    - Piecewise decay.

    Useful for backend-agnostic rate/weight changes for learning rates,
    exploration epsilons, beta parameters for prioritized replay, loss
    weights decay, etc.

    Each schedule can be called directly with the `t` (absolute time step)
    value and returns the value dependent on the Schedule and the passed
    time.

    Subclasses must implement `_value`; `value` and `__call__` are the
    public entry points and dispatch to it.
    """

    def __init__(self, framework=None):
        """Stores the framework specifier for this schedule.

        Args:
            framework: The framework specifier. It is passed through
                `check_framework` and the result is kept in
                `self.framework`; `value` compares it against "tf".
        """
        self.framework = check_framework(framework)

    @abstractmethod
    def _value(self, t):
        """
        Returns the value based on a time step input.

        Args:
            t (int): The time step. This could be a tf.Tensor.

        Returns:
            any: The calculated value depending on the schedule and `t`.
        """
        raise NotImplementedError

    def value(self, t):
        """Returns the schedule's value at time step `t`.

        Args:
            t (int): The time step. This could be a tf.Tensor.

        Returns:
            any: If `self.framework` is "tf", the result of running
                `_value` through `tf.py_function` (declared as float64)
                and casting it to float32 in an op named
                "schedule_value". Otherwise the plain return value of
                `self._value(t)`.
        """
        if self.framework == "tf":
            # `_value` is ordinary Python code, so it is wrapped in a
            # py_function op to be usable with a tensor `t`.
            return tf.cast(
                tf.py_function(self._value, [t], tf.float64),
                tf.float32,
                name="schedule_value")
        return self._value(t)

    def __call__(self, t):
        """
        Simply calls `self.value(t)`.
        """
        return self.value(t)