2021-12-15 22:32:52 +01:00
|
|
|
from typing import Optional
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
from ray.rllib.utils.annotations import override, PublicAPI
|
2020-07-30 12:49:32 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
2020-01-28 20:07:55 +01:00
|
|
|
from ray.rllib.utils.schedules.schedule import Schedule
|
2021-12-15 22:32:52 +01:00
|
|
|
from ray.rllib.utils.typing import TensorType
|
2020-01-28 20:07:55 +01:00
|
|
|
|
2020-07-30 12:49:32 +02:00
|
|
|
# Obtain TensorFlow handles through RLlib's import helper rather than a direct
# `import tensorflow`. The helper returns a 3-tuple unpacked here as
# (tf1, tf, tfv). NOTE(review): presumably these are the TF1-compat module,
# the main module and the TF version; verify against `try_import_tf`.
# In the code visible here, only `tf` is used (by `_tf_value_op`).
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
2020-01-28 20:07:55 +01:00
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@PublicAPI
class ConstantSchedule(Schedule):
    """Schedule that yields one fixed value regardless of the time step.

    Both the pure-Python path (`_value`) and the TensorFlow op path
    (`_tf_value_op`) ignore `t` and return the value supplied at
    construction time.
    """

    def __init__(self, value: float, framework: Optional[str] = None):
        """Creates a schedule pinned to a single value.

        Args:
            value: The value this schedule returns for every time step.
            framework: Framework descriptor string forwarded to the
                `Schedule` base class, e.g. "tf", "torch", or None.
        """
        super().__init__(framework=framework)
        # Both accessor methods below read the constant from `_v`.
        self._v = value

    @override(Schedule)
    def _value(self, t: TensorType) -> TensorType:
        """Returns the stored constant; `t` is accepted but not used."""
        fixed_value = self._v
        return fixed_value

    @override(Schedule)
    def _tf_value_op(self, t: TensorType) -> TensorType:
        """Returns the stored constant wrapped in `tf.constant`.

        `t` is accepted for interface compatibility but not used.
        """
        make_constant = tf.constant
        return make_constant(self._v)
|