import gym
import logging
from typing import Dict, List, Union

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.tf_utils import make_tf_callable
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
    TrainerConfigDict,
)

logger = logging.getLogger(__name__)

tf1, tf, tfv = try_import_tf()


@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
            )
            self.cur_lr = tf1.get_variable(
                "lr", initializer=self._lr_schedule.value(0), trainable=False
            )
            if self.framework == "tf":
                self._lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
                self._lr_update = self.cur_lr.assign(
                    self._lr_placeholder, read_value=False
                )

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule is not None:
            new_val = self._lr_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._lr_update, feed_dict={self._lr_placeholder: new_val}
                )
            else:
                self.cur_lr.assign(new_val, read_value=False)
                # This property (self._optimizer) is (still) accessible for
                # both TFPolicy and any TFPolicy_eager.
                self._optimizer.learning_rate.assign(self.cur_lr)

    @override(TFPolicy)
    def optimizer(self):
        if self.framework == "tf":
            return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
        else:
            return tf.keras.optimizers.Adam(self.cur_lr)
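

# Illustrative note (an assumption based on how `PiecewiseSchedule` behaves, not
# part of the original module): `lr_schedule` is a list of `[timestep, lr]`
# breakpoints that are linearly interpolated, with the last value held
# afterwards. For example, with the hypothetical schedule
#
#     lr_schedule = [[0, 5e-4], [1_000_000, 5e-5]]
#
# `on_global_var_update()` would set `cur_lr` to 5e-4 at timestep 0, to roughly
# 2.75e-4 at timestep 500_000 (linear interpolation), and to 5e-5 (the
# `outside_value`) for any timestep beyond 1_000_000.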


@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = get_variable(
                entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False
            )
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )

            self.entropy_coeff = get_variable(
                self._entropy_coeff_schedule.value(0),
                framework="tf",
                tf_name="entropy_coeff",
                trainable=False,
            )
            if self.framework == "tf":
                self._entropy_coeff_placeholder = tf1.placeholder(
                    dtype=tf.float32, name="entropy_coeff"
                )
                self._entropy_coeff_update = self.entropy_coeff.assign(
                    self._entropy_coeff_placeholder, read_value=False
                )

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            new_val = self._entropy_coeff_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._entropy_coeff_update,
                    feed_dict={self._entropy_coeff_placeholder: new_val},
                )
            else:
                self.entropy_coeff.assign(new_val, read_value=False)
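

# Illustrative note (hypothetical values, not part of the original module): the
# two accepted `entropy_coeff_schedule` formats behave as follows:
#
#     entropy_coeff_schedule = [[0, 0.01], [100_000, 0.001]]   # list format
#         # -> piecewise-linear decay from 0.01 to 0.001, then held at 0.001.
#     entropy_coeff_schedule = 100_000                          # scalar format
#         # -> linear decay from `entropy_coeff` at timestep 0 down to 0.0 at
#         #    timestep 100_000, then held at 0.0 (the outside_value).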


class KLCoeffMixin:
    """Assigns the `update_kl()` and other KL-related methods to a TFPolicy.

    This is used in Trainers to update the KL coefficient after each
    learning step based on `config.kl_target` and the measured KL value
    (from the train_batch).
    """

    def __init__(self, config):
        # The current KL value (as python float).
        self.kl_coeff_val = config["kl_coeff"]
        # The current KL value (as tf Variable for in-graph operations).
        self.kl_coeff = get_variable(
            float(self.kl_coeff_val),
            tf_name="kl_coeff",
            trainable=False,
            framework=config["framework"],
        )
        # Constant target value.
        self.kl_target = config["kl_target"]
        if self.framework == "tf":
            self._kl_coeff_placeholder = tf1.placeholder(
                dtype=tf.float32, name="kl_coeff"
            )
            self._kl_coeff_update = self.kl_coeff.assign(
                self._kl_coeff_placeholder, read_value=False
            )

    def update_kl(self, sampled_kl):
        # Update the current KL value based on the recently measured value.
        # Increase.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff_val *= 1.5
        # Decrease.
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff_val *= 0.5
        # No change.
        else:
            return self.kl_coeff_val

        # Make sure the new value is also stored in the graph/tf variable.
        self._set_kl_coeff(self.kl_coeff_val)

        # Return the current KL value.
        return self.kl_coeff_val

    def _set_kl_coeff(self, new_kl_coeff):
        # Set the (off-graph) value.
        self.kl_coeff_val = new_kl_coeff

        # Update the tf/tf2 Variable (via session call for tf or `assign`).
        if self.framework == "tf":
            self.get_session().run(
                self._kl_coeff_update,
                feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val},
            )
        else:
            self.kl_coeff.assign(self.kl_coeff_val, read_value=False)

    @override(Policy)
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        state = super().get_state()
        # Add current kl-coeff value.
        state["current_kl_coeff"] = self.kl_coeff_val
        return state

    @override(Policy)
    def set_state(self, state: dict) -> None:
        # Set current kl-coeff value first.
        self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"]))
        # Call super's set_state with the rest of the state dict.
        super().set_state(state)
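

# Worked example of the adaptive-KL rule above (hypothetical numbers, not part
# of the original module): with `kl_target=0.01` and a current `kl_coeff=0.2`,
# a measured `sampled_kl=0.03` exceeds `2.0 * 0.01`, so `update_kl(0.03)` scales
# the coefficient up to 0.3; a measured `sampled_kl=0.004` falls below
# `0.5 * 0.01`, so the coefficient is halved to 0.1; anything in between leaves
# it unchanged.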


class ValueNetworkMixin:
    """Assigns the `_value()` method to a TFPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:

            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            @make_tf_callable(self.get_session())
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                if isinstance(self.model, tf.keras.Model):
                    _, _, extra_outs = self.model(input_dict)
                    return extra_outs[SampleBatch.VF_PREDS][0]
                else:
                    model_out, _ = self.model(input_dict)
                    # [0] = remove the batch dim.
                    return self.model.value_function()[0]

        # When not doing GAE, we do not require the value function's output.
        else:

            @make_tf_callable(self.get_session())
            def value(*args, **kwargs):
                return tf.constant(0.0)

        self._value = value

        self._should_cache_extra_action = config["framework"] == "tf"
        self._cached_extra_action_fetches = None

    def _extra_action_out_impl(self) -> Dict[str, TensorType]:
        extra_action_out = super().extra_action_out_fn()
        # Keras models return values for each call in the third return argument
        # (dict).
        if isinstance(self.model, tf.keras.Model):
            return extra_action_out
        # Return value function outputs. VF estimates will hence be added to the
        # SampleBatches produced by the sampler(s) to generate the train batches
        # going into the loss function.
        extra_action_out.update(
            {
                SampleBatch.VF_PREDS: self.model.value_function(),
            }
        )
        return extra_action_out

    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        if not self._should_cache_extra_action:
            return self._extra_action_out_impl()

        # Note: there are 2 reasons we are caching the extra_action_fetches for
        # the TF1 static graph here.
        # 1. For better performance, so we don't query the base class and model
        #    for extra fetches every single time.
        # 2. For correctness. TF1 is special because the static graph may contain
        #    two logical graphs: one created by DynamicTFPolicy for action
        #    computation, and one created by MultiGPUTower for GPU training.
        #    Depending on which logical graph ran last,
        #    self.model.value_function() will point to the output tensor of that
        #    specific logical graph, causing problems if we try to fetch actions
        #    (run inference) using the training output tensor. For that reason,
        #    we cache the action output tensor from the vanilla DynamicTFPolicy
        #    once and call it a day.
        if self._cached_extra_action_fetches is not None:
            return self._cached_extra_action_fetches

        self._cached_extra_action_fetches = self._extra_action_out_impl()
        return self._cached_extra_action_fetches
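

# Usage sketch (the exact call site is an assumption, not part of the original
# module): during trajectory postprocessing a policy typically bootstraps the
# value of the final observation with something like
#
#     last_r = policy._value(**input_dict)
#
# where `input_dict` holds a single timestep (the last one of the rollout), so
# `_value()` runs one forward pass and returns a scalar VF estimate.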


class TargetNetworkMixin:
    """Assigns the `update_target` method to the SimpleQTFPolicy.

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
    ):
        @make_tf_callable(self.get_session())
        def do_update():
            # update_target_fn will be called periodically to copy the Q
            # network to the target Q network.
            update_target_expr = []
            assert len(self.q_func_vars) == len(self.target_q_func_vars), (
                self.q_func_vars,
                self.target_q_func_vars,
            )
            for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
                update_target_expr.append(var_target.assign(var))
                logger.debug("Update target op {}".format(var_target))
            return tf.group(*update_target_expr)

        self.update_target = do_update

    @property
    def q_func_vars(self):
        if not hasattr(self, "_q_func_vars"):
            self._q_func_vars = self.model.variables()
        return self._q_func_vars

    @property
    def target_q_func_vars(self):
        if not hasattr(self, "_target_q_func_vars"):
            self._target_q_func_vars = self.target_model.variables()
        return self._target_q_func_vars

    @override(TFPolicy)
    def variables(self):
        return self.q_func_vars + self.target_q_func_vars
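

# Usage sketch (an assumption about the call site, not part of the original
# module): the learner typically triggers a hard sync of the target network on
# a fixed timestep interval, roughly
#
#     if timestep - last_update >= config["target_network_update_freq"]:
#         policy.update_target()  # copies q_func_vars onto target_q_func_vars
#         last_update = timestep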


# TODO: find a better place for this util, since it's not technically a mixin.
@DeveloperAPI
def compute_gradients(
    policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    # Compute the gradients.
    variables = policy.model.trainable_variables
    if isinstance(policy.model, ModelV2):
        variables = variables()
    grads_and_vars = optimizer.compute_gradients(loss, variables)

    # Clip by global norm, if necessary.
    if policy.config["grad_clip"] is not None:
        # Defuse inf gradients (due to super large losses).
        grads = [g for (g, v) in grads_and_vars]
        grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
        # If the global_norm is inf -> all grads will be NaN. Stabilize this
        # here by setting them to 0.0. This will simply ignore destructive loss
        # calculations.
        policy.grads = []
        for g in grads:
            if g is not None:
                policy.grads.append(tf.where(tf.math.is_nan(g), tf.zeros_like(g), g))
            else:
                policy.grads.append(None)
        clipped_grads_and_vars = list(zip(policy.grads, variables))
        return clipped_grads_and_vars
    else:
        return grads_and_vars
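

# Illustrative note (hypothetical numbers, not part of the original module):
# with `config["grad_clip"] = 40.0`, `tf.clip_by_global_norm` rescales all
# gradients by `40.0 / max(global_norm, 40.0)`; e.g. a global norm of 80.0
# halves every gradient, while a global norm of 10.0 leaves them untouched.
# If the global norm is inf, the clipped gradients come back as NaN, which the
# loop above then resets to zeros so that one pathological loss does not
# destroy the weights.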