Allow EntropyCoeffSchedule to accept custom schedule (#6158)

* Modify tf_policy so that EntropyCoeffSchedule also accepts a list-based schedule, and avoid the negative entropy coefficient values that the current implementation can produce

* Update custom_metrics_and_callbacks.py

* Update tf_policy.py
This commit is contained in:
waldroje 2019-11-14 03:45:43 -05:00 committed by Eric Liang
parent e4565c9cc6
commit e4c0843f60
2 changed files with 17 additions and 8 deletions

View file

@ -84,5 +84,4 @@ if __name__ == "__main__":
assert "pole_angle_min" in custom_metrics
assert "pole_angle_max" in custom_metrics
assert "num_batches_mean" in custom_metrics
assert type(custom_metrics["pole_angle_mean"]) is float
assert "callback_ok" in trials[0].last_result

View file

@ -596,14 +596,24 @@ class EntropyCoeffSchedule(object):
# Set up the entropy coefficient variable and the schedule object that
# on_global_var_update later samples.
#
# Args:
#     entropy_coeff: initial coefficient; used as the variable initializer
#         and as the starting value of the non-list schedules.
#     entropy_coeff_schedule: None, a list, or a single value (see the
#         branches below for how each is interpreted).
#
# NOTE(review): this listing is a diff rendering with +/- markers and
# indentation stripped, so removed and added lines may be interleaved --
# confirm against the repository before relying on it.
def __init__(self, entropy_coeff, entropy_coeff_schedule):
# Non-trainable variable named "entropy_coeff", seeded with the initial value.
self.entropy_coeff = tf.get_variable(
"entropy_coeff", initializer=entropy_coeff, trainable=False)
# Keeps the raw schedule argument as passed in.
# NOTE(review): possibly a line removed by this commit -- TODO confirm.
self._entropy_schedule = entropy_coeff_schedule
if entropy_coeff_schedule is None:
# No schedule supplied: wrap the initial coefficient in a ConstantSchedule.
self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
else:
# Allows for custom schedule similar to lr_schedule format
# (presumably a list of [timestep, value] pairs, as in the branch
# below -- TODO confirm). outside_value is taken from the last element
# of the last entry.
if isinstance(entropy_coeff_schedule, list):
self.entropy_coeff_schedule = PiecewiseSchedule(
entropy_coeff_schedule,
outside_value=entropy_coeff_schedule[-1][-1])
else:
# Implements previous version but enforces outside_value:
# endpoints are [0, entropy_coeff] and [entropy_coeff_schedule, 0.0],
# with outside_value=0.0 (presumably so the coefficient stays at 0.0
# beyond the last endpoint rather than going negative -- TODO confirm
# PiecewiseSchedule semantics).
self.entropy_coeff_schedule = PiecewiseSchedule(
[[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
outside_value=0.0)
# Hook that forwards global_vars to the parent class and then loads a new
# value into the entropy_coeff variable based on global_vars["timestep"].
#
# NOTE(review): this listing is a diff rendering with +/- markers and
# indentation stripped. The `if self._entropy_schedule is not None:` guard
# and the multiplicative expression below appear to be the pre-change lines,
# and the `self.entropy_coeff_schedule.value(...)` line their replacement --
# confirm against the repository.
@override(Policy)
def on_global_var_update(self, global_vars):
super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
if self._entropy_schedule is not None:
self.entropy_coeff.load(
# Scales the variable's current value by
# (1 - timestep / config["entropy_coeff_schedule"]); that factor is
# negative once timestep exceeds the configured value.
self.entropy_coeff.eval(session=self._sess) *
(1 - global_vars["timestep"] /
self.config["entropy_coeff_schedule"]),
# Value of the schedule object built in __init__ at the current timestep.
self.entropy_coeff_schedule.value(global_vars["timestep"]),
session=self._sess)