mirror of
https://github.com/vale981/ray
synced 2025-03-10 21:36:39 -04:00
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
![]() |
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
from ray.rllib.utils import try_import_tf
|
||
|
|
||
|
tf = try_import_tf()
|
||
|
|
||
|
|
||
|
def huber_loss(x, delta=1.0):
    """Elementwise Huber loss of `x`.

    Uses 0.5 * x**2 where |x| < delta, and delta * (|x| - 0.5 * delta)
    elsewhere.

    Reference: https://en.wikipedia.org/wiki/Huber_loss
    """
    magnitude = tf.abs(x)
    inside = magnitude < delta
    quadratic_part = tf.square(x) * 0.5
    linear_part = delta * (magnitude - 0.5 * delta)
    return tf.where(inside, quadratic_part, linear_part)
|
||
|
|
||
|
|
||
|
def reduce_mean_ignore_inf(x, axis):
    """Same as tf.reduce_mean() but ignores -inf values.

    An element is left out of both the sum and the element count when it
    equals `tf.float32.min` (the most negative finite float32, used as a
    stand-in for -inf) or when it is an actual `-inf`.

    Args:
        x: Floating-point tensor, or a value convertible to one.
        axis: Axis (or axes) to reduce over, as accepted by tf.reduce_sum().

    Returns:
        The mean over `axis` of the non-ignored elements of `x`, in the
        dtype of `x`. A slice in which every element is ignored divides
        0 by 0 and therefore yields NaN.
    """
    x = tf.convert_to_tensor(x)
    # tf.float32.min is finite, so checking only against it would let a
    # genuine -inf through and turn the whole mean into -inf.
    mask = tf.logical_and(
        tf.not_equal(x, tf.float32.min), tf.not_equal(x, float("-inf")))
    x_zeroed = tf.where(mask, x, tf.zeros_like(x))
    # Count kept elements in x's own dtype so that non-float32 inputs
    # (e.g. float64) can be divided without a dtype mismatch.
    count = tf.reduce_sum(tf.cast(mask, x.dtype), axis)
    return tf.reduce_sum(x_zeroed, axis) / count
|
||
|
|
||
|
|
||
|
def minimize_and_clip(optimizer, objective, var_list, clip_val=10):
    """Compute gradients of `objective` w.r.t. `var_list`, clipped per variable.

    Despite its name, this function only computes the gradients; applying
    them to the variables is left to the caller.

    Args:
        optimizer: Object exposing `compute_gradients(objective,
            var_list=...)` that returns a list of (gradient, variable)
            pairs.
        objective: The loss to differentiate.
        var_list: Variables to compute gradients for.
        clip_val: Maximum L2 norm allowed for each individual gradient
            (each gradient is clipped independently via tf.clip_by_norm).
            Pass None to disable clipping.

    Returns:
        List of (gradient, variable) pairs in the order produced by
        `optimizer.compute_gradients`. Pairs whose gradient is None are
        passed through unchanged.

    Raises:
        ValueError: If `clip_val` is not None and not strictly positive.
    """
    # A zero or negative clip norm would silently zero out (or NaN) every
    # gradient, so reject it before doing any work.
    if clip_val is not None and clip_val <= 0:
        raise ValueError(
            "clip_val must be positive or None, got {}".format(clip_val))
    gradients = optimizer.compute_gradients(objective, var_list=var_list)
    if clip_val is None:
        return list(gradients)
    return [
        (tf.clip_by_norm(grad, clip_val) if grad is not None else grad, var)
        for grad, var in gradients
    ]
|
||
|
|
||
|
|
||
|
def scope_vars(scope, trainable_only=False):
    """
    Get variables inside a scope

    The scope can be specified as a string (any string type, including
    `unicode` on Python 2) or as a VariableScope.

    Parameters
    ----------
    scope: str or VariableScope
        scope in which the variables reside.
    trainable_only: bool
        whether or not to return only the variables that were marked as
        trainable.

    Returns
    -------
    vars: [tf.Variable]
        list of variables in `scope`.
    """
    # Use `.name` only when the object has one (a VariableScope); any string
    # type is used as-is. A plain `isinstance(scope, str)` check would
    # wrongly treat a Python 2 `unicode` name as a scope object.
    scope_name = getattr(scope, "name", scope)
    # GLOBAL_VARIABLES is the non-deprecated key for the collection that
    # the old `GraphKeys.VARIABLES` alias referred to.
    collection = (tf.GraphKeys.TRAINABLE_VARIABLES
                  if trainable_only else tf.GraphKeys.GLOBAL_VARIABLES)
    return tf.get_collection(collection, scope=scope_name)
|