[RLlib] DQN rainbow eager-mode (keras style NoisyLayer) (preparation for native tf2.x support). (#9304)

Sven Mika 2020-07-09 10:44:10 +02:00 committed by GitHub
parent c37d30a079
commit 01125b8fcf
14 changed files with 87 additions and 166 deletions

View file

@@ -425,7 +425,7 @@ py_test(
 py_test(
     name = "test_dqn",
     tags = ["agents_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["agents/dqn/tests/test_dqn.py"]
 )
 py_test(

View file

@@ -1,5 +1,4 @@
-import numpy as np
+from ray.rllib.models.tf.layers import NoisyLayer
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.utils.framework import try_import_tf
@@ -69,13 +68,15 @@ class DistributionalQTFModel(TFModelV2):
         self.model_out = tf.keras.layers.Input(
             shape=(num_outputs, ), name="model_out")

-        def build_action_value(model_out):
+        def build_action_value(prefix, model_out):
             if q_hiddens:
                 action_out = model_out
                 for i in range(len(q_hiddens)):
                     if use_noisy:
-                        action_out = self._noisy_layer(
-                            "hidden_%d" % i, action_out, q_hiddens[i], sigma0)
+                        action_out = NoisyLayer(
+                            "{}hidden_{}".format(prefix, i),
+                            q_hiddens[i],
+                            sigma0)(action_out)
                     elif add_layer_norm:
                         action_out = tf.keras.layers.Dense(
                             units=q_hiddens[i],
@@ -94,12 +95,11 @@ class DistributionalQTFModel(TFModelV2):
                 action_out = model_out

             if use_noisy:
-                action_scores = self._noisy_layer(
-                    "output",
-                    action_out,
+                action_scores = NoisyLayer(
+                    "{}output".format(prefix),
                     self.action_space.n * num_atoms,
                     sigma0,
-                    non_linear=False)
+                    activation=None)(action_out)
             elif q_hiddens:
                 action_scores = tf.keras.layers.Dense(
                     units=self.action_space.n * num_atoms,
@@ -130,13 +130,14 @@ class DistributionalQTFModel(TFModelV2):
                 dist = tf.expand_dims(tf.ones_like(action_scores), -1)
                 return [action_scores, logits, dist]

-        def build_state_score(model_out):
+        def build_state_score(prefix, model_out):
             state_out = model_out
             for i in range(len(q_hiddens)):
                 if use_noisy:
-                    state_out = self._noisy_layer("dueling_hidden_%d" % i,
-                                                  state_out, q_hiddens[i],
-                                                  sigma0)
+                    state_out = NoisyLayer(
+                        "{}dueling_hidden_{}".format(prefix, i),
+                        q_hiddens[i],
+                        sigma0)(state_out)
                 else:
                     state_out = tf.keras.layers.Dense(
                         units=q_hiddens[i], activation=tf.nn.relu)(state_out)
@@ -144,59 +145,23 @@ class DistributionalQTFModel(TFModelV2):
                     state_out = tf.keras.layers.LayerNormalization()(
                         state_out)
             if use_noisy:
-                state_score = self._noisy_layer(
-                    "dueling_output",
-                    state_out,
+                state_score = NoisyLayer(
+                    "{}dueling_output".format(prefix),
                     num_atoms,
                     sigma0,
-                    non_linear=False)
+                    activation=None)(state_out)
             else:
                 state_score = tf.keras.layers.Dense(
                     units=num_atoms, activation=None)(state_out)
             return state_score

-        if tf1.executing_eagerly():
-            from tensorflow.python.ops import variable_scope
-            # Have to use a variable store to reuse variables in eager mode
-            store = variable_scope.EagerVariableStore()
-            # Save the scope objects, since in eager we will execute this
-            # path repeatedly and there is no guarantee it will always be run
-            # in the same original scope.
-            with tf1.variable_scope(name + "/action_value") as action_scope:
-                pass
-            with tf1.variable_scope(name + "/state_value") as state_scope:
-                pass
-
-            def build_action_value_in_scope(model_out):
-                with store.as_default():
-                    with tf1.variable_scope(
-                            action_scope, reuse=tf1.AUTO_REUSE):
-                        return build_action_value(model_out)
-
-            def build_state_score_in_scope(model_out):
-                with store.as_default():
-                    with tf1.variable_scope(
-                            state_scope, reuse=tf1.AUTO_REUSE):
-                        return build_state_score(model_out)
-        else:
-
-            def build_action_value_in_scope(model_out):
-                with tf1.variable_scope(
-                        name + "/action_value", reuse=tf1.AUTO_REUSE):
-                    return build_action_value(model_out)
-
-            def build_state_score_in_scope(model_out):
-                with tf1.variable_scope(
-                        name + "/state_value", reuse=tf1.AUTO_REUSE):
-                    return build_state_score(model_out)
-
-        q_out = build_action_value_in_scope(self.model_out)
+        q_out = build_action_value(name + "/action_value/", self.model_out)
         self.q_value_head = tf.keras.Model(self.model_out, q_out)
         self.register_variables(self.q_value_head.variables)

         if dueling:
-            state_out = build_state_score_in_scope(self.model_out)
+            state_out = build_state_score(
+                name + "/state_value/", self.model_out)
             self.state_value_head = tf.keras.Model(self.model_out, state_out)
             self.register_variables(self.state_value_head.variables)
@@ -219,66 +184,3 @@ class DistributionalQTFModel(TFModelV2):
         """Returns the state value prediction for the given state embedding."""
         return self.state_value_head(model_out)
-
-    def _noisy_layer(self,
-                     prefix,
-                     action_in,
-                     out_size,
-                     sigma0,
-                     non_linear=True):
-        """
-        a common dense layer: y = w^{T}x + b
-        a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
-            (b+\\epsilon_b*\\sigma_b)
-        where \\epsilon are random variables sampled from factorized normal
-        distributions and \\sigma are trainable variables which are expected
-        to vanish along the training procedure
-        """
-        in_size = int(action_in.shape[1])
-
-        epsilon_in = tf.random.normal(shape=[in_size])
-        epsilon_out = tf.random.normal(shape=[out_size])
-        epsilon_in = self._f_epsilon(epsilon_in)
-        epsilon_out = self._f_epsilon(epsilon_out)
-        epsilon_w = tf.matmul(
-            a=tf.expand_dims(epsilon_in, -1), b=tf.expand_dims(epsilon_out, 0))
-        epsilon_b = epsilon_out
-
-        sigma_w = tf1.get_variable(
-            name=prefix + "_sigma_w",
-            shape=[in_size, out_size],
-            dtype=tf.float32,
-            initializer=tf1.random_uniform_initializer(
-                minval=-1.0 / np.sqrt(float(in_size)),
-                maxval=1.0 / np.sqrt(float(in_size))))
-        # TF noise generation can be unreliable on GPU
-        # If generating the noise on the CPU,
-        # lowering sigma0 to 0.1 may be helpful
-        sigma_b = tf1.get_variable(
-            name=prefix + "_sigma_b",
-            shape=[out_size],
-            dtype=tf.float32,  # 0.5~GPU, 0.1~CPU
-            initializer=tf1.constant_initializer(
-                sigma0 / np.sqrt(float(in_size))))
-
-        w = tf1.get_variable(
-            name=prefix + "_fc_w",
-            shape=[in_size, out_size],
-            dtype=tf.float32,
-            initializer=tf.initializers.GlorotUniform())
-        b = tf1.get_variable(
-            name=prefix + "_fc_b",
-            shape=[out_size],
-            dtype=tf.float32,
-            initializer=tf.initializers.Zeros())
-
-        action_activation = \
-            tf.keras.layers.Lambda(lambda x: tf.matmul(
-                x, w + sigma_w * epsilon_w) + b + sigma_b * epsilon_b)(
-                    action_in)
-
-        if not non_linear:
-            return action_activation
-        return tf.nn.relu(action_activation)
-
-    def _f_epsilon(self, x):
-        return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x))
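Reviewer note: the removed `_noisy_layer` helper re-created variables via `tf1.get_variable` on every call and leaned on variable scopes for reuse, which is exactly what breaks under eager execution. A keras-style layer owns its variables instead. Below is a minimal, self-contained sketch of a factorized-Gaussian noisy dense layer (Fortunato et al. 2017) in that style; class and argument names here are illustrative assumptions, the real implementation lives in ray.rllib.models.tf.layers.

import numpy as np
import tensorflow as tf


class NoisyDenseSketch(tf.keras.layers.Layer):
    """Factorized-Gaussian noisy dense layer (sketch only).

    y = (w + sigma_w * eps_w)^T x + (b + sigma_b * eps_b), where the eps
    terms are freshly sampled noise and the sigma terms are trainable and
    expected to shrink as training converges.
    """

    def __init__(self, out_size, sigma0=0.5, activation=tf.nn.relu, **kwargs):
        super().__init__(**kwargs)
        self.out_size = out_size
        self.sigma0 = sigma0
        self.activation = activation

    def build(self, input_shape):
        self.in_size = int(input_shape[-1])
        bound = 1.0 / np.sqrt(self.in_size)
        # Variables are created exactly once here, instead of via
        # tf1.get_variable on every forward pass.
        self.w = self.add_weight(
            name="fc_w", shape=(self.in_size, self.out_size),
            initializer=tf.keras.initializers.GlorotUniform())
        self.b = self.add_weight(
            name="fc_b", shape=(self.out_size, ), initializer="zeros")
        self.sigma_w = self.add_weight(
            name="sigma_w", shape=(self.in_size, self.out_size),
            initializer=tf.keras.initializers.RandomUniform(-bound, bound))
        self.sigma_b = self.add_weight(
            name="sigma_b", shape=(self.out_size, ),
            initializer=tf.keras.initializers.Constant(self.sigma0 * bound))

    def call(self, inputs):
        # f(x) = sign(x) * sqrt(|x|): the factorization trick that turns
        # two noise vectors into a full in_size x out_size noise matrix.
        def f(x):
            return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x))

        eps_in = f(tf.random.normal((self.in_size, )))
        eps_out = f(tf.random.normal((self.out_size, )))
        eps_w = tf.expand_dims(eps_in, -1) * tf.expand_dims(eps_out, 0)
        out = (tf.matmul(inputs, self.w + self.sigma_w * eps_w) +
               self.b + self.sigma_b * eps_out)
        return self.activation(out) if self.activation is not None else out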

View file

@@ -338,8 +338,8 @@ class NullContextManager:
 @DeveloperAPI
 def flatten(obs, framework):
     """Flatten the given tensor."""
-    if framework == "tf":
-        return tf1.layers.flatten(obs)
+    if framework in ["tf", "tfe"]:
+        return tf1.keras.layers.Flatten()(obs)
     elif framework == "torch":
         assert torch is not None
         return torch.flatten(obs, start_dim=1)
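`tf1.layers.flatten` is a graph-mode-era API that is deprecated in TF 2.x; routing through a keras `Flatten` layer behaves the same under both graph and eager execution, which is why the branch now also covers "tfe". A quick sanity check (hypothetical usage):

import numpy as np
import tensorflow as tf

# Flatten keeps the batch dimension and collapses everything else.
obs = tf.constant(np.zeros((4, 8, 8, 3), dtype=np.float32))
flat = tf.keras.layers.Flatten()(obs)
assert tuple(flat.shape) == (4, 8 * 8 * 3)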

View file

@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()

-class GRUGate(tf.keras.layers.Layer):
+class GRUGate(tf.keras.layers.Layer if tf else object):
     def __init__(self, init_bias=0., **kwargs):
         super().__init__(**kwargs)
         self._init_bias = init_bias
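The `tf.keras.layers.Layer if tf else object` base-class trick (applied identically to MultiHeadAttention, NoisyLayer, RelativeMultiHeadAttention, and SkipConnection below) keeps these modules importable when TensorFlow is not installed: `try_import_tf()` then yields `None`s and the class harmlessly derives from `object`. A minimal sketch of the pattern, assuming `try_import_tf` returns `(None, None, None)` without TF:

from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()  # All three are None if TF is absent.


class MyLayer(tf.keras.layers.Layer if tf else object):
    """Importable without TF; instantiating it still requires TF."""

    def call(self, inputs):
        return inputs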

View file

@@ -8,7 +8,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()

-class MultiHeadAttention(tf.keras.layers.Layer):
+class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
     """A multi-head attention layer described in [1]."""

     def __init__(self, out_dim, num_heads, head_dim, **kwargs):

View file

@@ -6,7 +6,7 @@ from ray.rllib.utils.framework import get_activation_fn, get_variable, \
 tf1, tf, tfv = try_import_tf()

-class NoisyLayer(tf.keras.layers.Layer):
+class NoisyLayer(tf.keras.layers.Layer if tf else object):
     """A Layer that adds learnable Noise
     a common dense layer: y = w^{T}x + b
     a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
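Judging from the call sites in distributional_q_tf_model.py above, the keras-style constructor takes a name prefix, the output size, `sigma0`, and an optional `activation` (defaulting to a nonlinearity; `activation=None` yields the linear output layer). A hedged usage sketch based on that inferred signature:

from ray.rllib.models.tf.layers import NoisyLayer
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

inputs = tf.keras.layers.Input(shape=(256, ))
# A hidden noisy layer (default activation), then a linear noisy output.
hidden = NoisyLayer("hidden_0", 128, 0.5)(inputs)
q_vals = NoisyLayer("output", 4, 0.5, activation=None)(hidden)
model = tf.keras.Model(inputs, q_vals)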

View file

@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()

-class RelativeMultiHeadAttention(tf.keras.layers.Layer):
+class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
     """A RelativeMultiHeadAttention layer as described in [3].

     Uses segment level recurrence with state reuse.

View file

@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()

-class SkipConnection(tf.keras.layers.Layer):
+class SkipConnection(tf.keras.layers.Layer if tf else object):
     """Skip connection layer.

     Adds the original input to the output (regular residual layer) OR uses

View file

@@ -200,7 +200,7 @@ def build_eager_tf_policy(name,
     class eager_policy_cls(base):
         def __init__(self, observation_space, action_space, config):
             assert tf.executing_eagerly()
-            self.framework = "tf"
+            self.framework = "tfe"
             Policy.__init__(self, observation_space, action_space, config)
             self._is_training = False
             self._loss_initialized = False
@@ -235,7 +235,7 @@ def build_eager_tf_policy(name,
                 action_space,
                 logit_dim,
                 config["model"],
-                framework="tf",
+                framework=self.framework,
             )
             self.exploration = self._create_exploration()
             self._state_in = [
@@ -352,7 +352,8 @@ def build_eager_tf_policy(name,
                     self.model,
                     input_dict[SampleBatch.CUR_OBS],
                     explore=explore,
-                    timestep=timestep)
+                    timestep=timestep,
+                    episodes=episodes)
             else:
                 # Exploration hook before each forward pass.
                 self.exploration.before_compute_actions(
@@ -457,8 +458,10 @@ def build_eager_tf_policy(name,
             return _convert_to_numpy(self.exploration.get_info())

         @override(Policy)
-        def get_weights(self):
+        def get_weights(self, as_dict=False):
             variables = self.variables()
+            if as_dict:
+                return {v.name: v.numpy() for v in variables}
             return [v.numpy() for v in variables]

         @override(Policy)
@@ -638,8 +641,8 @@ def build_eager_tf_policy(name,
             dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

             # Convert everything to tensors.
-            dummy_batch = tf.nest.map_structure(tf1.convert_to_tensor,
-                                                dummy_batch)
+            dummy_batch = tf.nest.map_structure(
+                tf1.convert_to_tensor, dummy_batch)

             # for IMPALA which expects a certain sample batch size.
             def tile_to(tensor, n):
@@ -650,6 +653,11 @@ def build_eager_tf_policy(name,
                 dummy_batch = tf.nest.map_structure(
                     lambda c: tile_to(c, get_batch_divisibility_req(self)),
                     dummy_batch)
+            i = 0
+            self._state_in = []
+            while "state_in_{}".format(i) in dummy_batch:
+                self._state_in.append(dummy_batch["state_in_{}".format(i)])
+                i += 1

             # Execute a forward pass to get self.action_dist etc initialized,
             # and also obtain the extra action fetches
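Two of the eager-policy changes above are easy to exercise from user code: `get_weights(as_dict=True)` returns name-keyed numpy arrays (handy for inspecting or diffing checkpoints), and the new `_state_in` loop re-derives the RNN state inputs from the `state_in_0`, `state_in_1`, ... keys of the dummy batch. Hypothetical usage of the former, assuming `policy` is an eager policy built with framework="tfe":

flat_weights = policy.get_weights()               # list of np.ndarrays
named_weights = policy.get_weights(as_dict=True)  # {var_name: np.ndarray}
for name, value in named_weights.items():
    print(name, value.shape)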

View file

@@ -57,7 +57,7 @@ class EpsilonGreedy(Exploration):
             0, framework=framework, tf_name="timestep")

         # Build the tf-info-op.
-        if self.framework == "tf":
+        if self.framework in ["tf", "tfe"]:
             self._tf_info_op = self.get_info()

     @override(Exploration)
@@ -68,7 +68,7 @@ class EpsilonGreedy(Exploration):
                                  explore: bool = True):
         q_values = action_distribution.inputs

-        if self.framework == "tf":
+        if self.framework in ["tf", "tfe"]:
             return self._get_tf_exploration_action_op(q_values, explore,
                                                       timestep)
         else:
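The branches here only choose where the op gets built; the underlying rule is plain epsilon-greedy over the Q-values carried in `action_distribution.inputs`. A simplified standalone sketch of that rule (RLlib's actual op additionally handles invalid/masked actions):

import tensorflow as tf


def epsilon_greedy(q_values, epsilon):
    """With prob. epsilon pick a uniform random action, else the argmax."""
    batch_size = tf.shape(q_values)[0]
    num_actions = tf.shape(q_values)[1]
    greedy = tf.argmax(q_values, axis=1)  # int64 action indices.
    random = tf.random.uniform(
        (batch_size, ), maxval=tf.cast(num_actions, tf.int64),
        dtype=tf.int64)
    explore_mask = tf.random.uniform((batch_size, )) < epsilon
    return tf.where(explore_mask, random, greedy)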

View file

@@ -290,10 +290,9 @@ class ParameterNoise(Exploration):
     def _sample_new_noise(self, *, tf_sess=None):
         """Samples new noise and stores it in `self.noise`."""
         if self.framework == "tf":
-            if tf.executing_eagerly():
-                self._tf_sample_new_noise_op()
-            else:
-                tf_sess.run(self.tf_sample_new_noise_op)
+            tf_sess.run(self.tf_sample_new_noise_op)
+        elif self.framework == "tfe":
+            self._tf_sample_new_noise_op()
         else:
             for i in range(len(self.noise)):
                 self.noise[i] = torch.normal(
@@ -312,7 +311,7 @@ class ParameterNoise(Exploration):
         return tf.group(*added_noises)

     def _sample_new_noise_and_add(self, *, tf_sess=None, override=False):
-        if self.framework == "tf" and not tf.executing_eagerly():
+        if self.framework == "tf":
             if override and self.weights_are_currently_noisy:
                 tf_sess.run(self.tf_remove_noise_op)
             tf_sess.run(self.tf_sample_new_noise_and_add_op)
@@ -338,12 +337,11 @@ class ParameterNoise(Exploration):
         # Make sure we only add noise to currently noise-free weights.
         assert self.weights_are_currently_noisy is False

-        if self.framework == "tf":
-            if tf.executing_eagerly():
-                self._tf_add_stored_noise_op()
-            else:
-                tf_sess.run(self.tf_add_stored_noise_op)
+        # Add stored noise to the model's parameters.
+        if self.framework == "tf":
+            tf_sess.run(self.tf_add_stored_noise_op)
+        elif self.framework == "tfe":
+            self._tf_add_stored_noise_op()
         else:
             for i in range(len(self.noise)):
                 # Add noise to weights in-place.
@@ -377,13 +375,12 @@ class ParameterNoise(Exploration):
         # Make sure we only remove noise iff currently noisy.
         assert self.weights_are_currently_noisy is True

         # Removes the stored noise from the model's parameters.
         if self.framework == "tf":
-            if tf.executing_eagerly():
-                self._tf_remove_noise_op()
-            else:
-                tf_sess.run(self.tf_remove_noise_op)
+            tf_sess.run(self.tf_remove_noise_op)
+        elif self.framework == "tfe":
+            self._tf_remove_noise_op()
         else:
             # Removes the stored noise from the model's parameters.
             for var, noise in zip(self.model_variables, self.noise):
                 # Remove noise from weights in-place.
                 var.add_(-noise)
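After this change each framework string maps to exactly one code path: graph mode ("tf") goes through `tf_sess.run`, eager mode ("tfe") calls the op-building methods directly, and torch mutates tensors in place. The eager branch reduces to straightforward in-place variable updates, roughly like this (helper names are illustrative, not RLlib's):

import tensorflow as tf


def sample_noise(variables, stddev=1.0):
    # One fresh Gaussian noise tensor per model variable.
    return [tf.random.normal(v.shape, stddev=stddev) for v in variables]


def add_noise(variables, noise):
    for v, n in zip(variables, noise):
        v.assign_add(n)  # Perturb weights in place.


def remove_noise(variables, noise):
    for v, n in zip(variables, noise):
        v.assign_sub(n)  # Restore the noise-free weights.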

View file

@@ -28,7 +28,7 @@ class Random(Exploration):
         Args:
             action_space (Space): The gym action space used by the environment.
-            framework (Optional[str]): One of None, "tf", "torch".
+            framework (Optional[str]): One of None, "tf", "tfe", "torch".
         """
         super().__init__(
             action_space=action_space,
@@ -46,7 +46,7 @@ class Random(Exploration):
                                timestep: Union[int, TensorType],
                                explore: bool = True):
         # Instantiate the distribution object.
-        if self.framework == "tf":
+        if self.framework in ["tf", "tfe"]:
             return self.get_tf_exploration_action_op(action_distribution,
                                                      explore)
         else:
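When exploring, the tf branch ultimately just samples uniformly from the action space; for a `Discrete(n)` space that is a single `tf.random.uniform` call. A minimal sketch:

import tensorflow as tf


def random_discrete_actions(batch_size, num_actions):
    # Uniform-random actions over a Discrete(num_actions) space.
    return tf.random.uniform(
        (batch_size, ), maxval=num_actions, dtype=tf.int64)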

View file

@@ -1,17 +1,17 @@
 import logging
 import os
 import sys
-from typing import Any, Union
+from typing import Any, Optional
+
+from ray.rllib.utils.types import TensorStructType, TensorShape, TensorType

 logger = logging.getLogger(__name__)

 # Represents a generic tensor type.
-# TODO(ekl) this is duplicated in types.py
-TensorType = Any
+TensorType = TensorType
 # Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
-# TODO(ekl) this is duplicated in types.py
-TensorStructType = Union[TensorType, dict, tuple]
+TensorStructType = TensorStructType

 def try_import_tf(error=False):
@@ -39,6 +39,9 @@ def try_import_tf(error=False):
     if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

+    # TODO: (sven) Allow env var to force compat.v1 behavior even if tf2.x
+    # installed.
     # Try to reuse already imported tf module. This will avoid going through
     # the initial import steps below and thereby switching off v2_behavior
     # (switching off v2 behavior twice breaks all-framework tests for eager).
@@ -160,15 +163,18 @@ def _torch_stubs():
 def get_variable(value,
-                 framework="tf",
-                 trainable=False,
-                 tf_name="unnamed-variable",
-                 torch_tensor=False,
-                 device=None):
+                 framework: str = "tf",
+                 trainable: bool = False,
+                 tf_name: str = "unnamed-variable",
+                 torch_tensor: bool = False,
+                 device: Optional[str] = None,
+                 shape: Optional[TensorShape] = None,
+                 dtype: Optional[Any] = None):
     """
     Args:
         value (any): The initial value to use. In the non-tf case, this will
-            be returned as is.
+            be returned as is. In the tf case, this could be a tf-Initializer
+            object.
         framework (str): One of "tf", "torch", or None.
         trainable (bool): Whether the generated variable should be
             trainable (tf)/require_grad (torch) or not (default: False).
@@ -176,19 +182,27 @@ def get_variable(value,
             tf.Variable.
         torch_tensor (bool): For framework="torch": Whether to actually create
             a torch.tensor, or just a python value (default).
+        device (Optional[torch.Device]): An optional torch device to use for
+            the created torch tensor.
+        shape (Optional[TensorShape]): An optional shape to use iff `value`
+            does not have any (e.g. if it's an initializer w/o explicit value).
+        dtype (Optional[TensorType]): An optional dtype to use iff `value` does
+            not have any (e.g. if it's an initializer w/o explicit value).

     Returns:
         any: A framework-specific variable (tf.Variable, torch.tensor, or
             python primitive).
     """
-    if framework == "tf":
+    if framework in ["tf", "tfe"]:
         import tensorflow as tf
-        dtype = getattr(
+        dtype = dtype or getattr(
             value, "dtype", tf.float32
             if isinstance(value, float) else tf.int32
             if isinstance(value, int) else None)
         return tf.compat.v1.get_variable(
-            tf_name, initializer=value, dtype=dtype, trainable=trainable)
+            tf_name, initializer=value, dtype=dtype, trainable=trainable,
+            **({} if shape is None else {"shape": shape})
+        )
     elif framework == "torch" and torch_tensor is True:
         torch, _ = try_import_torch()
         var_ = torch.from_numpy(value)
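The new `shape`/`dtype` kwargs matter when `value` is a tf initializer rather than a concrete value, since an initializer carries neither. A hedged usage sketch (run inside a graph context, as `tf.compat.v1.get_variable` is not eager-safe):

from ray.rllib.utils.framework import get_variable, try_import_tf

tf1, tf, tfv = try_import_tf()

with tf1.Graph().as_default():
    # Pass an initializer instead of a concrete value; shape and dtype
    # must then be supplied explicitly.
    weights = get_variable(
        tf.keras.initializers.GlorotUniform(),
        framework="tf",
        trainable=True,
        tf_name="my_weights",
        shape=(16, 4),
        dtype=tf.float32)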

View file

@@ -35,7 +35,7 @@ class Schedule(metaclass=ABCMeta):
         Returns:
             any: The calculated value depending on the schedule and `t`.
         """
-        if self.framework == "tf" and not tf.executing_eagerly():
+        if self.framework in ["tf", "tfe"]:
             return self._tf_value_op(t)
         return self._value(t)
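Concrete `Schedule` subclasses implement `_value(t)`, while `_tf_value_op` builds the same computation as tf ops; after this change the op path also serves eager ("tfe"). For intuition, a linear-decay schedule's `_value` is roughly the following (illustrative only, not RLlib's exact class):

def linear_schedule_value(t, schedule_timesteps, initial_p=1.0, final_p=0.05):
    # Linearly anneal from initial_p to final_p over schedule_timesteps,
    # then stay at final_p.
    fraction = min(float(t) / schedule_timesteps, 1.0)
    return initial_p + fraction * (final_p - initial_p)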