ray/rllib/policy/tf_mixins.py


import logging
from typing import Dict, List, Union

import gym

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import get_variable, try_import_tf
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.tf_utils import make_tf_callable
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    LocalOptimizer,
    ModelGradients,
    TensorType,
)

logger = logging.getLogger(__name__)

tf1, tf, tfv = try_import_tf()


@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
            )
            self.cur_lr = tf1.get_variable(
                "lr", initializer=self._lr_schedule.value(0), trainable=False
            )
            if self.framework == "tf":
                self._lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
                self._lr_update = self.cur_lr.assign(
                    self._lr_placeholder, read_value=False
                )

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule is not None:
            new_val = self._lr_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._lr_update, feed_dict={self._lr_placeholder: new_val}
                )
            else:
                self.cur_lr.assign(new_val, read_value=False)
                # This property (self._optimizer) is (still) accessible for
                # both TFPolicy and any TFPolicy_eager.
                self._optimizer.learning_rate.assign(self.cur_lr)

    @override(TFPolicy)
    def optimizer(self):
        if self.framework == "tf":
            return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
        else:
            return tf.keras.optimizers.Adam(self.cur_lr)
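

# Illustrative sketch (editor's addition, not part of the original module):
# `lr_schedule` is a list of [timestep, lr-value] pairs that PiecewiseSchedule
# interpolates between (linearly, by default). For example:
#
#   from ray.rllib.utils.schedules import PiecewiseSchedule
#
#   sched = PiecewiseSchedule(
#       [[0, 1e-3], [1_000_000, 1e-5]], outside_value=1e-5, framework=None
#   )
#   sched.value(0)          # -> 1e-3
#   sched.value(500_000)    # -> ~5.05e-4 (linear interpolation)
#   sched.value(2_000_000)  # -> 1e-5 (outside_value past the last timestep)
#
# The mixin queries this schedule with the global `timestep` inside
# `on_global_var_update()` and writes the result into the `cur_lr` variable.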


@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = get_variable(
                entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False
            )
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )
            self.entropy_coeff = get_variable(
                self._entropy_coeff_schedule.value(0),
                framework="tf",
                tf_name="entropy_coeff",
                trainable=False,
            )
            if self.framework == "tf":
                self._entropy_coeff_placeholder = tf1.placeholder(
                    dtype=tf.float32, name="entropy_coeff"
                )
                self._entropy_coeff_update = self.entropy_coeff.assign(
                    self._entropy_coeff_placeholder, read_value=False
                )

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            new_val = self._entropy_coeff_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._entropy_coeff_update,
                    feed_dict={self._entropy_coeff_placeholder: new_val},
                )
            else:
                self.entropy_coeff.assign(new_val, read_value=False)
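

# Illustrative sketch (editor's addition, not part of the original module):
# `entropy_coeff_schedule` accepts two formats, mirroring the branches above:
#
#   # 1) A list of [timestep, coeff] pairs (same format as `lr_schedule`):
#   entropy_coeff_schedule = [[0, 0.01], [2_000_000, 0.0]]
#   # 2) A single number, interpreted as the timestep at which the coefficient
#   #    has decayed linearly from `entropy_coeff` down to 0.0:
#   entropy_coeff_schedule = 2_000_000
#
# In both cases the schedule is evaluated against the global `timestep` in
# `on_global_var_update()` and written into the `entropy_coeff` variable.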


class KLCoeffMixin:
    """Assigns the `update_kl()` and other KL-related methods to a TFPolicy.

    This is used in Algorithms to update the KL coefficient after each
    learning step based on `config.kl_target` and the measured KL value
    (from the train_batch).
    """

    def __init__(self, config: AlgorithmConfigDict):
        # The current KL value (as python float).
        self.kl_coeff_val = config["kl_coeff"]
        # The current KL value (as tf Variable for in-graph operations).
        self.kl_coeff = get_variable(
            float(self.kl_coeff_val),
            tf_name="kl_coeff",
            trainable=False,
            framework=config["framework"],
        )
        # Constant target value.
        self.kl_target = config["kl_target"]
        if self.framework == "tf":
            self._kl_coeff_placeholder = tf1.placeholder(
                dtype=tf.float32, name="kl_coeff"
            )
            self._kl_coeff_update = self.kl_coeff.assign(
                self._kl_coeff_placeholder, read_value=False
            )

    def update_kl(self, sampled_kl):
        # Update the current KL value based on the recently measured value.
        # Increase.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff_val *= 1.5
        # Decrease.
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff_val *= 0.5
        # No change.
        else:
            return self.kl_coeff_val

        # Make sure the new value is also stored in the graph/tf variable.
        self._set_kl_coeff(self.kl_coeff_val)

        # Return the current KL value.
        return self.kl_coeff_val

    def _set_kl_coeff(self, new_kl_coeff):
        # Set the (off graph) value.
        self.kl_coeff_val = new_kl_coeff

        # Update the tf/tf2 Variable (via session call for tf or `assign`).
        if self.framework == "tf":
            self.get_session().run(
                self._kl_coeff_update,
                feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val},
            )
        else:
            self.kl_coeff.assign(self.kl_coeff_val, read_value=False)

    @override(Policy)
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        state = super().get_state()
        # Add current kl-coeff value.
        state["current_kl_coeff"] = self.kl_coeff_val
        return state

    @override(Policy)
    def set_state(self, state: dict) -> None:
        # Set current kl-coeff value first.
        self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"]))
        # Call super's set_state with rest of the state dict.
        super().set_state(state)
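

# Illustrative sketch (editor's addition, not part of the original module):
# With, e.g., kl_target=0.01 and an initial kl_coeff of 0.2, `update_kl()`
# adapts the coefficient as follows:
#
#   policy.update_kl(sampled_kl=0.05)   # 0.05 > 2 * 0.01   -> 0.2 * 1.5 = 0.3
#   policy.update_kl(sampled_kl=0.001)  # 0.001 < 0.5 * 0.01 -> 0.3 * 0.5 = 0.15
#   policy.update_kl(sampled_kl=0.012)  # within [0.005, 0.02] -> unchanged
#
# Whenever the value changes, it is mirrored into the in-graph `kl_coeff`
# variable via `_set_kl_coeff()` so the loss uses the new coefficient.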


class TargetNetworkMixin:
    """Assigns the `update_target` method to the SimpleQTFPolicy.

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
    ):
        trainable_q_func_vars = self.model.trainable_variables()
        trainable_target_q_func_vars = self.target_model.trainable_variables()

        @make_tf_callable(self.get_session())
        def do_update():
            # update_target_fn will be called periodically to copy Q network to
            # target Q network.
            update_target_expr = []
            assert len(trainable_q_func_vars) == len(trainable_target_q_func_vars), (
                trainable_q_func_vars,
                trainable_target_q_func_vars,
            )
            for var, var_target in zip(
                trainable_q_func_vars, trainable_target_q_func_vars
            ):
                update_target_expr.append(var_target.assign(var))
                logger.debug("Update target op {}".format(var_target))
            return tf.group(*update_target_expr)

        self.update_target = do_update

    @property
    def q_func_vars(self):
        if not hasattr(self, "_q_func_vars"):
            self._q_func_vars = self.model.variables()
        return self._q_func_vars

    @property
    def target_q_func_vars(self):
        if not hasattr(self, "_target_q_func_vars"):
            self._target_q_func_vars = self.target_model.variables()
        return self._target_q_func_vars
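

# Illustrative sketch (editor's addition, not part of the original module):
# A policy built with this mixin exposes `update_target()`, which copies the
# trainable Q-network weights into the target network in a single call:
#
#   # Typically triggered by the algorithm every `target_network_update_freq`
#   # timesteps (the exact call site depends on the algorithm implementation).
#   policy.update_target()
#
# `q_func_vars` / `target_q_func_vars` expose the (cached) full variable lists
# of the two models, e.g. for debugging or computing custom update ops.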


class ValueNetworkMixin:
    """Assigns the `_value()` method to a TFPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:

            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            @make_tf_callable(self.get_session())
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                if isinstance(self.model, tf.keras.Model):
                    _, _, extra_outs = self.model(input_dict)
                    return extra_outs[SampleBatch.VF_PREDS][0]
                else:
                    model_out, _ = self.model(input_dict)
                    # [0] = remove the batch dim.
                    return self.model.value_function()[0]

        # When not doing GAE, we do not require the value function's output.
        else:

            @make_tf_callable(self.get_session())
            def value(*args, **kwargs):
                return tf.constant(0.0)

        self._value = value

        self._should_cache_extra_action = config["framework"] == "tf"
        self._cached_extra_action_fetches = None

    def _extra_action_out_impl(self) -> Dict[str, TensorType]:
        extra_action_out = super().extra_action_out_fn()
        # Keras models return values for each call in third return argument
        # (dict).
        if isinstance(self.model, tf.keras.Model):
            return extra_action_out
        # Return value function outputs. VF estimates will hence be added to the
        # SampleBatches produced by the sampler(s) to generate the train batches
        # going into the loss function.
        extra_action_out.update(
            {
                SampleBatch.VF_PREDS: self.model.value_function(),
            }
        )
        return extra_action_out

    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        if not self._should_cache_extra_action:
            return self._extra_action_out_impl()

        # Note: there are 2 reasons we are caching the extra_action_fetches for
        # TF1 static graph here.
        # 1. for better performance, so we don't query base class and model for
        #    extra fetches every single time.
        # 2. for correctness. TF1 is special because the static graph may contain
        #    two logical graphs. One created by DynamicTFPolicy for action
        #    computation, and one created by MultiGPUTower for GPU training.
        #    Depending on which logical graph ran last time,
        #    self.model.value_function() will point to the output tensor
        #    of the specific logical graph, causing problems if we try to
        #    fetch actions (run inference) using the training output tensor.
        #    For that reason, we cache the action output tensor from the
        #    vanilla DynamicTFPolicy once and call it a day.
        if self._cached_extra_action_fetches is not None:
            return self._cached_extra_action_fetches

        self._cached_extra_action_fetches = self._extra_action_out_impl()
        return self._cached_extra_action_fetches
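

# Illustrative sketch (editor's addition, not part of the original module):
# `_value()` performs a fresh forward pass on a single (batch-size 1) input
# dict and is typically used to bootstrap the value of the last observation of
# a trajectory during GAE postprocessing, roughly:
#
#   input_dict = sample_batch.get_single_step_input_dict(
#       policy.model.view_requirements, index="last"
#   )
#   last_r = policy._value(**input_dict)  # scalar VF estimate (batch dim removed)
#
# In contrast, `extra_action_out_fn()` only fetches `model.value_function()`
# from the most recent action-computation forward pass (no extra pass).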


# TODO: find a better place for this util, since it's not technically a mixin.
@DeveloperAPI
def compute_gradients(
    policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    # Compute the gradients.
    variables = policy.model.trainable_variables
    if isinstance(policy.model, ModelV2):
        variables = variables()
    grads_and_vars = optimizer.compute_gradients(loss, variables)

    # Clip by global norm, if necessary.
    if policy.config.get("grad_clip") is not None:
        # Defuse inf gradients (due to super large losses).
        grads = [g for (g, v) in grads_and_vars]
        grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
        # If the global_norm is inf -> All grads will be NaN. Stabilize this
        # here by setting them to 0.0. This will simply ignore destructive loss
        # calculations.
        policy.grads = []
        for g in grads:
            if g is not None:
                policy.grads.append(tf.where(tf.math.is_nan(g), tf.zeros_like(g), g))
            else:
                policy.grads.append(None)
        clipped_grads_and_vars = list(zip(policy.grads, variables))
        return clipped_grads_and_vars
    else:
        return grads_and_vars
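

# Illustrative sketch (editor's addition, not part of the original module):
# This util is meant to be plugged into a TFPolicy's gradient computation,
# e.g. as (or inside) a custom `compute_gradients_fn`, roughly:
#
#   def my_compute_gradients_fn(policy, optimizer, loss):
#       # Clips by global norm if `policy.config["grad_clip"]` is set and
#       # zeroes out NaN gradients caused by an inf global norm.
#       return compute_gradients(policy, optimizer, loss)
#
# The returned list of (grad, var) tuples can then be passed to
# `optimizer.apply_gradients()`.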