2020-01-28 20:07:55 +01:00
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
|
|
|
|
from ray.rllib.utils.framework import check_framework
|
2020-01-30 20:27:57 +01:00
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
|
|
|
|
# Bind TensorFlow through RLlib's `try_import_tf` helper rather than a plain
# `import tensorflow`. NOTE(review): what `tf` is when TensorFlow is not
# installed is decided by that helper (not visible here) — confirm there.
tf = try_import_tf()
|
2020-01-28 20:07:55 +01:00
|
|
|
|
|
|
|
|
|
|
|
class Schedule(metaclass=ABCMeta):
    """Abstract base for time-dependent value schedules.

    Concrete subclasses implement a particular scheduling scheme, such as:
    - Constant behavior.
    - Linear decay.
    - Piecewise decay.

    Meant for backend-agnostic changes of rates/weights over time, e.g.
    learning rates, exploration epsilons, beta parameters for prioritized
    replay, or loss-weight decay.

    Instances are callable: `schedule(t)` with the absolute time step `t`
    yields the schedule's value at that time (identical to
    `schedule.value(t)`).
    """

    def __init__(self, framework):
        """Validates and stores the framework identifier.

        Args:
            framework: Framework identifier. It is passed through
                `check_framework`, and that call's result is stored as
                `self.framework`.
        """
        validated = check_framework(framework)
        self.framework = validated

    @abstractmethod
    def _value(self, t):
        """Computes the schedule's value for time step `t`.

        Subclasses must override this; the base version only raises.

        Args:
            t (int): The time step. When the framework is "tf", this is
                invoked from inside `tf.py_function`, so `t` may be a
                tensor rather than a Python int.

        Returns:
            any: The value of the schedule at `t`.

        Raises:
            NotImplementedError: Always, in this abstract base version.
        """
        raise NotImplementedError

    def value(self, t):
        """Returns the schedule's value at time step `t`.

        For any framework other than "tf", this simply returns
        `self._value(t)`. For "tf", `self._value` is wrapped in
        `tf.py_function` (declared output dtype float64), and that result is
        cast to float32 so the schedule can be evaluated inside a graph.

        Args:
            t: The time step; for the "tf" framework it is handed to
                `tf.py_function` as the single input.

        Returns:
            Whatever `_value` returns, or, for "tf", the float32 `tf.cast`
            result created with the op name "schedule_value".
        """
        if self.framework != "tf":
            return self._value(t)
        # Run the Python-side computation through py_function, then narrow
        # the float64 output to float32.
        as_float64 = tf.py_function(self._value, [t], tf.float64)
        return tf.cast(as_float64, tf.float32, name="schedule_value")

    def __call__(self, t):
        """Shorthand for `self.value(t)`."""
        return self.value(t)