mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Allow EntropyCoeffSchedule to accept custom schedule (#6158)
* modify tf_policy to enable EntropyCoeffSchedule to handle list, and avoid negative values under current implementation * Update custom_metrics_and_callbacks.py * Update tf_policy.py
This commit is contained in:
parent
e4565c9cc6
commit
e4c0843f60
2 changed files with 17 additions and 8 deletions
|
@ -84,5 +84,4 @@ if __name__ == "__main__":
|
|||
assert "pole_angle_min" in custom_metrics
|
||||
assert "pole_angle_max" in custom_metrics
|
||||
assert "num_batches_mean" in custom_metrics
|
||||
assert type(custom_metrics["pole_angle_mean"]) is float
|
||||
assert "callback_ok" in trials[0].last_result
|
||||
|
|
|
@ -596,14 +596,24 @@ class EntropyCoeffSchedule(object):
|
|||
def __init__(self, entropy_coeff, entropy_coeff_schedule):
|
||||
self.entropy_coeff = tf.get_variable(
|
||||
"entropy_coeff", initializer=entropy_coeff, trainable=False)
|
||||
self._entropy_schedule = entropy_coeff_schedule
|
||||
|
||||
if entropy_coeff_schedule is None:
|
||||
self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
|
||||
else:
|
||||
# Allows for custom schedule similar to lr_schedule format
|
||||
if isinstance(entropy_coeff_schedule, list):
|
||||
self.entropy_coeff_schedule = PiecewiseSchedule(
|
||||
entropy_coeff_schedule,
|
||||
outside_value=entropy_coeff_schedule[-1][-1])
|
||||
else:
|
||||
# Implements previous version but enforces outside_value
|
||||
self.entropy_coeff_schedule = PiecewiseSchedule(
|
||||
[[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
|
||||
outside_value=0.0)
|
||||
|
||||
@override(Policy)
|
||||
def on_global_var_update(self, global_vars):
|
||||
super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
|
||||
if self._entropy_schedule is not None:
|
||||
self.entropy_coeff.load(
|
||||
self.entropy_coeff.eval(session=self._sess) *
|
||||
(1 - global_vars["timestep"] /
|
||||
self.config["entropy_coeff_schedule"]),
|
||||
self.entropy_coeff_schedule.value(global_vars["timestep"]),
|
||||
session=self._sess)
|
||||
|
|
Loading…
Add table
Reference in a new issue