# Patch summary (commentary only; everything above the first "diff --git"
# header is skipped by patch tools).
#
# 1. rllib/examples/custom_metrics_and_callbacks.py
#    Drops the check that custom_metrics["pole_angle_mean"] has the exact
#    type `float`; the presence checks for the metric keys and for
#    "callback_ok" are left in place.
#
# 2. rllib/policy/tf_policy.py, class EntropyCoeffSchedule
#    __init__ no longer stores the raw setting in `_entropy_schedule`.
#    It builds `self.entropy_coeff_schedule` instead:
#      - None      -> ConstantSchedule(entropy_coeff)
#      - a list    -> PiecewiseSchedule over that list, with outside_value
#                     taken from the last element of the last entry
#      - otherwise -> PiecewiseSchedule from [0, entropy_coeff] to
#                     [entropy_coeff_schedule, 0.0], with outside_value=0.0
#    on_global_var_update now always loads
#    `entropy_coeff_schedule.value(global_vars["timestep"])` into the
#    variable. The removed code instead read the variable's current value
#    and multiplied it by (1 - timestep / config["entropy_coeff_schedule"])
#    before loading it back, and only did so when a schedule was set.
#
# NOTE(review): ConstantSchedule and PiecewiseSchedule are not imported by
#   this patch - presumably tf_policy.py already imports them; confirm.
# NOTE(review): an empty list would fail on `entropy_coeff_schedule[-1][-1]`;
#   confirm callers never pass one.
# NOTE(review): line breaks and indentation below were reconstructed from a
#   whitespace-collapsed copy of this patch; verify the context lines
#   against the target files before applying.
diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py
index 99512cab9..ee8df41e4 100644
--- a/rllib/examples/custom_metrics_and_callbacks.py
+++ b/rllib/examples/custom_metrics_and_callbacks.py
@@ -84,5 +84,4 @@ if __name__ == "__main__":
     assert "pole_angle_min" in custom_metrics
     assert "pole_angle_max" in custom_metrics
     assert "num_batches_mean" in custom_metrics
-    assert type(custom_metrics["pole_angle_mean"]) is float
     assert "callback_ok" in trials[0].last_result
diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py
index 8fa4ecb19..67d51a42b 100644
--- a/rllib/policy/tf_policy.py
+++ b/rllib/policy/tf_policy.py
@@ -596,14 +596,24 @@ class EntropyCoeffSchedule(object):
     def __init__(self, entropy_coeff, entropy_coeff_schedule):
         self.entropy_coeff = tf.get_variable(
             "entropy_coeff", initializer=entropy_coeff, trainable=False)
-        self._entropy_schedule = entropy_coeff_schedule
+
+        if entropy_coeff_schedule is None:
+            self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
+        else:
+            # Allows for custom schedule similar to lr_schedule format
+            if isinstance(entropy_coeff_schedule, list):
+                self.entropy_coeff_schedule = PiecewiseSchedule(
+                    entropy_coeff_schedule,
+                    outside_value=entropy_coeff_schedule[-1][-1])
+            else:
+                # Implements previous version but enforces outside_value
+                self.entropy_coeff_schedule = PiecewiseSchedule(
+                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
+                    outside_value=0.0)
 
     @override(Policy)
     def on_global_var_update(self, global_vars):
         super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
-        if self._entropy_schedule is not None:
-            self.entropy_coeff.load(
-                self.entropy_coeff.eval(session=self._sess) *
-                (1 - global_vars["timestep"] /
-                 self.config["entropy_coeff_schedule"]),
-                session=self._sess)
+        self.entropy_coeff.load(
+            self.entropy_coeff_schedule.value(global_vars["timestep"]),
+            session=self._sess)