Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)
[RLlib] DQN rainbow eager-mode (keras style NoisyLayer) (preparation for native tf2.x support). (#9304)
Parent: c37d30a079
Commit: 01125b8fcf
14 changed files with 87 additions and 166 deletions
@@ -425,7 +425,7 @@ py_test(
 py_test(
     name = "test_dqn",
     tags = ["agents_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["agents/dqn/tests/test_dqn.py"]
 )
 py_test(
@@ -1,5 +1,4 @@
-import numpy as np
-
+from ray.rllib.models.tf.layers import NoisyLayer
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.utils.framework import try_import_tf
 
@@ -69,13 +68,15 @@ class DistributionalQTFModel(TFModelV2):
         self.model_out = tf.keras.layers.Input(
             shape=(num_outputs, ), name="model_out")
 
-        def build_action_value(model_out):
+        def build_action_value(prefix, model_out):
             if q_hiddens:
                 action_out = model_out
                 for i in range(len(q_hiddens)):
                     if use_noisy:
-                        action_out = self._noisy_layer(
-                            "hidden_%d" % i, action_out, q_hiddens[i], sigma0)
+                        action_out = NoisyLayer(
+                            "{}hidden_{}".format(prefix, i),
+                            q_hiddens[i],
+                            sigma0)(action_out)
                     elif add_layer_norm:
                         action_out = tf.keras.layers.Dense(
                             units=q_hiddens[i],
@@ -94,12 +95,11 @@ class DistributionalQTFModel(TFModelV2):
                 action_out = model_out
 
             if use_noisy:
-                action_scores = self._noisy_layer(
-                    "output",
-                    action_out,
+                action_scores = NoisyLayer(
+                    "{}output".format(prefix),
                     self.action_space.n * num_atoms,
                     sigma0,
-                    non_linear=False)
+                    activation=None)(action_out)
             elif q_hiddens:
                 action_scores = tf.keras.layers.Dense(
                     units=self.action_space.n * num_atoms,
@@ -130,13 +130,14 @@ class DistributionalQTFModel(TFModelV2):
                 dist = tf.expand_dims(tf.ones_like(action_scores), -1)
                 return [action_scores, logits, dist]
 
-        def build_state_score(model_out):
+        def build_state_score(prefix, model_out):
             state_out = model_out
             for i in range(len(q_hiddens)):
                 if use_noisy:
-                    state_out = self._noisy_layer("dueling_hidden_%d" % i,
-                                                  state_out, q_hiddens[i],
-                                                  sigma0)
+                    state_out = NoisyLayer(
+                        "{}dueling_hidden_{}".format(prefix, i),
+                        q_hiddens[i],
+                        sigma0)(state_out)
                 else:
                     state_out = tf.keras.layers.Dense(
                         units=q_hiddens[i], activation=tf.nn.relu)(state_out)
@@ -144,59 +145,23 @@ class DistributionalQTFModel(TFModelV2):
                     state_out = tf.keras.layers.LayerNormalization()(
                         state_out)
             if use_noisy:
-                state_score = self._noisy_layer(
-                    "dueling_output",
-                    state_out,
+                state_score = NoisyLayer(
+                    "{}dueling_output".format(prefix),
                     num_atoms,
                     sigma0,
-                    non_linear=False)
+                    activation=None)(state_out)
             else:
                 state_score = tf.keras.layers.Dense(
                     units=num_atoms, activation=None)(state_out)
             return state_score
 
-        if tf1.executing_eagerly():
-            from tensorflow.python.ops import variable_scope
-            # Have to use a variable store to reuse variables in eager mode
-            store = variable_scope.EagerVariableStore()
-
-            # Save the scope objects, since in eager we will execute this
-            # path repeatedly and there is no guarantee it will always be run
-            # in the same original scope.
-            with tf1.variable_scope(name + "/action_value") as action_scope:
-                pass
-            with tf1.variable_scope(name + "/state_value") as state_scope:
-                pass
-
-            def build_action_value_in_scope(model_out):
-                with store.as_default():
-                    with tf1.variable_scope(
-                            action_scope, reuse=tf1.AUTO_REUSE):
-                        return build_action_value(model_out)
-
-            def build_state_score_in_scope(model_out):
-                with store.as_default():
-                    with tf1.variable_scope(
-                            state_scope, reuse=tf1.AUTO_REUSE):
-                        return build_state_score(model_out)
-        else:
-
-            def build_action_value_in_scope(model_out):
-                with tf1.variable_scope(
-                        name + "/action_value", reuse=tf1.AUTO_REUSE):
-                    return build_action_value(model_out)
-
-            def build_state_score_in_scope(model_out):
-                with tf1.variable_scope(
-                        name + "/state_value", reuse=tf1.AUTO_REUSE):
-                    return build_state_score(model_out)
-
-        q_out = build_action_value_in_scope(self.model_out)
+        q_out = build_action_value(name + "/action_value/", self.model_out)
         self.q_value_head = tf.keras.Model(self.model_out, q_out)
         self.register_variables(self.q_value_head.variables)
 
         if dueling:
-            state_out = build_state_score_in_scope(self.model_out)
+            state_out = build_state_score(
+                name + "/state_value/", self.model_out)
             self.state_value_head = tf.keras.Model(self.model_out, state_out)
             self.register_variables(self.state_value_head.variables)
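Note: the hunk above swaps the variable-scope based self._noisy_layer helper for the Keras-style NoisyLayer imported from ray.rllib.models.tf.layers. Below is a minimal sketch of what such a layer computes, y = (w + sigma_w * eps_w)^T x + (b + sigma_b * eps_b) with factorized Gaussian noise. This is an illustrative reimplementation, not the RLlib source; class name and initializer choices are assumptions.

import numpy as np
import tensorflow as tf


class NoisyLayerSketch(tf.keras.layers.Layer):
    """Illustrative noisy dense layer (NoisyNet-style), not RLlib code."""

    def __init__(self, name, out_size, sigma0, activation=tf.nn.relu):
        super().__init__(name=name)
        self.out_size = out_size
        self.sigma0 = sigma0
        self.activation = activation

    def build(self, input_shape):
        in_size = int(input_shape[1])
        bound = 1.0 / np.sqrt(float(in_size))
        # Plain dense parameters (the "mean" weights and bias).
        self.w = self.add_weight(
            "w", shape=(in_size, self.out_size),
            initializer=tf.keras.initializers.GlorotUniform())
        self.b = self.add_weight("b", shape=(self.out_size, ), initializer="zeros")
        # Trainable noise scales, initialized as in the NoisyNet paper.
        self.sigma_w = self.add_weight(
            "sigma_w", shape=(in_size, self.out_size),
            initializer=tf.keras.initializers.RandomUniform(-bound, bound))
        self.sigma_b = self.add_weight(
            "sigma_b", shape=(self.out_size, ),
            initializer=tf.keras.initializers.Constant(self.sigma0 * bound))

    def call(self, inputs):
        # Factorized Gaussian noise: f(x) = sign(x) * sqrt(|x|), then outer product.
        def f(x):
            return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x))

        eps_in = f(tf.random.normal([self.w.shape[0]]))
        eps_out = f(tf.random.normal([self.out_size]))
        eps_w = tf.expand_dims(eps_in, -1) * tf.expand_dims(eps_out, 0)
        out = tf.matmul(inputs, self.w + self.sigma_w * eps_w) \
            + self.b + self.sigma_b * eps_out
        return out if self.activation is None else self.activation(out)

Used the same way the diff uses NoisyLayer, e.g. NoisyLayerSketch("hidden_0", 256, 0.5)(model_out).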
@@ -219,66 +184,3 @@ class DistributionalQTFModel(TFModelV2):
         """Returns the state value prediction for the given state embedding."""
 
         return self.state_value_head(model_out)
-
-    def _noisy_layer(self,
-                     prefix,
-                     action_in,
-                     out_size,
-                     sigma0,
-                     non_linear=True):
-        """
-        a common dense layer: y = w^{T}x + b
-        a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
-            (b+\\epsilon_b*\\sigma_b)
-        where \\epsilon are random variables sampled from factorized normal
-        distributions and \\sigma are trainable variables which are expected to
-        vanish along the training procedure
-        """
-        in_size = int(action_in.shape[1])
-
-        epsilon_in = tf.random.normal(shape=[in_size])
-        epsilon_out = tf.random.normal(shape=[out_size])
-        epsilon_in = self._f_epsilon(epsilon_in)
-        epsilon_out = self._f_epsilon(epsilon_out)
-        epsilon_w = tf.matmul(
-            a=tf.expand_dims(epsilon_in, -1), b=tf.expand_dims(epsilon_out, 0))
-        epsilon_b = epsilon_out
-        sigma_w = tf1.get_variable(
-            name=prefix + "_sigma_w",
-            shape=[in_size, out_size],
-            dtype=tf.float32,
-            initializer=tf1.random_uniform_initializer(
-                minval=-1.0 / np.sqrt(float(in_size)),
-                maxval=1.0 / np.sqrt(float(in_size))))
-        # TF noise generation can be unreliable on GPU
-        # If generating the noise on the CPU,
-        # lowering sigma0 to 0.1 may be helpful
-        sigma_b = tf1.get_variable(
-            name=prefix + "_sigma_b",
-            shape=[out_size],
-            dtype=tf.float32,  # 0.5~GPU, 0.1~CPU
-            initializer=tf1.constant_initializer(
-                sigma0 / np.sqrt(float(in_size))))
-
-        w = tf1.get_variable(
-            name=prefix + "_fc_w",
-            shape=[in_size, out_size],
-            dtype=tf.float32,
-            initializer=tf.initializers.GlorotUniform())
-        b = tf1.get_variable(
-            name=prefix + "_fc_b",
-            shape=[out_size],
-            dtype=tf.float32,
-            initializer=tf.initializers.Zeros())
-
-        action_activation = \
-            tf.keras.layers.Lambda(lambda x: tf.matmul(
-                x, w + sigma_w * epsilon_w) + b + sigma_b * epsilon_b)(
-                    action_in)
-
-        if not non_linear:
-            return action_activation
-        return tf.nn.relu(action_activation)
-
-    def _f_epsilon(self, x):
-        return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x))
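Note: the removed helper's docstring describes the NoisyNet trick: the per-weight noise matrix eps_w is not sampled element-wise but factorized into an outer product of two transformed vectors, with f(x) = sign(x) * sqrt(|x|). A tiny NumPy sketch of that factorization (illustrative only):

import numpy as np


def f(x):
    # NoisyNet's factorization function.
    return np.sign(x) * np.sqrt(np.abs(x))


in_size, out_size = 4, 3
eps_in = f(np.random.normal(size=in_size))    # one sample per input unit
eps_out = f(np.random.normal(size=out_size))  # one sample per output unit
eps_w = np.outer(eps_in, eps_out)             # (in_size, out_size) noise matrix
eps_b = eps_out                               # bias noise reuses eps_out

Only in_size + out_size random numbers are drawn per forward pass instead of in_size * out_size.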
@@ -338,8 +338,8 @@ class NullContextManager:
 @DeveloperAPI
 def flatten(obs, framework):
     """Flatten the given tensor."""
-    if framework == "tf":
-        return tf1.layers.flatten(obs)
+    if framework in ["tf", "tfe"]:
+        return tf1.keras.layers.Flatten()(obs)
     elif framework == "torch":
         assert torch is not None
         return torch.flatten(obs, start_dim=1)
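Both branches above reduce a (batch, d1, d2, ...) tensor to (batch, d1*d2*...). A rough side-by-side, assuming TensorFlow 2.x and PyTorch are both installed (illustration only, not RLlib code):

import numpy as np
import tensorflow as tf
import torch

obs = np.zeros((32, 4, 5), dtype=np.float32)
tf_flat = tf.keras.layers.Flatten()(tf.constant(obs))           # shape (32, 20)
torch_flat = torch.flatten(torch.from_numpy(obs), start_dim=1)  # shape (32, 20)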
@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()
 
 
-class GRUGate(tf.keras.layers.Layer):
+class GRUGate(tf.keras.layers.Layer if tf else object):
     def __init__(self, init_bias=0., **kwargs):
         super().__init__(**kwargs)
         self._init_bias = init_bias
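The "tf.keras.layers.Layer if tf else object" idiom, repeated for the attention, noisy, and skip-connection layers below, keeps these modules importable when TensorFlow is not installed: try_import_tf() then yields None for tf, and the class degrades to a plain object base. A minimal sketch of the same pattern (MyLayer is a hypothetical name):

from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


# Falls back to `object`, so merely importing this module never requires TF.
class MyLayer(tf.keras.layers.Layer if tf else object):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units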
@@ -8,7 +8,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()
 
 
-class MultiHeadAttention(tf.keras.layers.Layer):
+class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
     """A multi-head attention layer described in [1]."""
 
     def __init__(self, out_dim, num_heads, head_dim, **kwargs):
@@ -6,7 +6,7 @@ from ray.rllib.utils.framework import get_activation_fn, get_variable, \
 tf1, tf, tfv = try_import_tf()
 
 
-class NoisyLayer(tf.keras.layers.Layer):
+class NoisyLayer(tf.keras.layers.Layer if tf else object):
     """A Layer that adds learnable Noise
     a common dense layer: y = w^{T}x + b
     a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()
 
 
-class RelativeMultiHeadAttention(tf.keras.layers.Layer):
+class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
     """A RelativeMultiHeadAttention layer as described in [3].
 
     Uses segment level recurrence with state reuse.
@@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf
 tf1, tf, tfv = try_import_tf()
 
 
-class SkipConnection(tf.keras.layers.Layer):
+class SkipConnection(tf.keras.layers.Layer if tf else object):
     """Skip connection layer.
 
     Adds the original input to the output (regular residual layer) OR uses
@@ -200,7 +200,7 @@ def build_eager_tf_policy(name,
     class eager_policy_cls(base):
         def __init__(self, observation_space, action_space, config):
             assert tf.executing_eagerly()
-            self.framework = "tf"
+            self.framework = "tfe"
             Policy.__init__(self, observation_space, action_space, config)
             self._is_training = False
             self._loss_initialized = False
@@ -235,7 +235,7 @@ def build_eager_tf_policy(name,
                 action_space,
                 logit_dim,
                 config["model"],
-                framework="tf",
+                framework=self.framework,
             )
             self.exploration = self._create_exploration()
             self._state_in = [
@@ -352,7 +352,8 @@ def build_eager_tf_policy(name,
                     self.model,
                     input_dict[SampleBatch.CUR_OBS],
                     explore=explore,
-                    timestep=timestep)
+                    timestep=timestep,
+                    episodes=episodes)
             else:
                 # Exploration hook before each forward pass.
                 self.exploration.before_compute_actions(
@@ -457,8 +458,10 @@ def build_eager_tf_policy(name,
             return _convert_to_numpy(self.exploration.get_info())
 
         @override(Policy)
-        def get_weights(self):
+        def get_weights(self, as_dict=False):
             variables = self.variables()
+            if as_dict:
+                return {v.name: v.numpy() for v in variables}
             return [v.numpy() for v in variables]
 
         @override(Policy)
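Hypothetical usage of the new as_dict flag, assuming policy is an instantiated eager policy (the variable name shown in the comment is made up): the list form stays the default, the dict form keys each array by variable name.

weights_list = policy.get_weights()                 # [np.ndarray, np.ndarray, ...]
weights_by_name = policy.get_weights(as_dict=True)  # e.g. {"dense/kernel:0": np.ndarray, ...}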
@@ -638,8 +641,8 @@ def build_eager_tf_policy(name,
             dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
 
             # Convert everything to tensors.
-            dummy_batch = tf.nest.map_structure(tf1.convert_to_tensor,
-                                                dummy_batch)
+            dummy_batch = tf.nest.map_structure(
+                tf1.convert_to_tensor, dummy_batch)
 
             # for IMPALA which expects a certain sample batch size.
             def tile_to(tensor, n):
@@ -650,6 +653,11 @@ def build_eager_tf_policy(name,
             dummy_batch = tf.nest.map_structure(
                 lambda c: tile_to(c, get_batch_divisibility_req(self)),
                 dummy_batch)
+            i = 0
+            self._state_in = []
+            while "state_in_{}".format(i) in dummy_batch:
+                self._state_in.append(dummy_batch["state_in_{}".format(i)])
+                i += 1
 
             # Execute a forward pass to get self.action_dist etc initialized,
             # and also obtain the extra action fetches
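The added loop simply gathers consecutive "state_in_0", "state_in_1", ... keys from the dummy batch into an ordered list. A plain-Python sketch of the same idea (toy data, not RLlib code):

dummy_batch = {"obs": [[0.0]], "state_in_0": [[0.1]], "state_in_1": [[0.2]]}

state_in = []
i = 0
while "state_in_{}".format(i) in dummy_batch:
    state_in.append(dummy_batch["state_in_{}".format(i)])
    i += 1
# state_in is now [[[0.1]], [[0.2]]], in index order.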
@@ -57,7 +57,7 @@ class EpsilonGreedy(Exploration):
             0, framework=framework, tf_name="timestep")
 
         # Build the tf-info-op.
-        if self.framework == "tf":
+        if self.framework in ["tf", "tfe"]:
             self._tf_info_op = self.get_info()
 
     @override(Exploration)
@@ -68,7 +68,7 @@ class EpsilonGreedy(Exploration):
                                explore: bool = True):
 
         q_values = action_distribution.inputs
-        if self.framework == "tf":
+        if self.framework in ["tf", "tfe"]:
            return self._get_tf_exploration_action_op(q_values, explore,
                                                       timestep)
         else:
@@ -290,10 +290,9 @@ class ParameterNoise(Exploration):
     def _sample_new_noise(self, *, tf_sess=None):
         """Samples new noise and stores it in `self.noise`."""
         if self.framework == "tf":
-            if tf.executing_eagerly():
-                self._tf_sample_new_noise_op()
-            else:
-                tf_sess.run(self.tf_sample_new_noise_op)
+            tf_sess.run(self.tf_sample_new_noise_op)
+        elif self.framework == "tfe":
+            self._tf_sample_new_noise_op()
         else:
             for i in range(len(self.noise)):
                 self.noise[i] = torch.normal(
@@ -312,7 +311,7 @@ class ParameterNoise(Exploration):
         return tf.group(*added_noises)
 
     def _sample_new_noise_and_add(self, *, tf_sess=None, override=False):
-        if self.framework == "tf" and not tf.executing_eagerly():
+        if self.framework == "tf":
             if override and self.weights_are_currently_noisy:
                 tf_sess.run(self.tf_remove_noise_op)
             tf_sess.run(self.tf_sample_new_noise_and_add_op)
@@ -338,12 +337,11 @@ class ParameterNoise(Exploration):
         # Make sure we only add noise to currently noise-free weights.
         assert self.weights_are_currently_noisy is False
 
-        if self.framework == "tf":
-            if tf.executing_eagerly():
-                self._tf_add_stored_noise_op()
-            else:
-                tf_sess.run(self.tf_add_stored_noise_op)
+        # Add stored noise to the model's parameters.
+        if self.framework == "tf":
+            tf_sess.run(self.tf_add_stored_noise_op)
+        elif self.framework == "tfe":
+            self._tf_add_stored_noise_op()
         else:
             for i in range(len(self.noise)):
                 # Add noise to weights in-place.
@@ -377,13 +375,12 @@ class ParameterNoise(Exploration):
         # Make sure we only remove noise iff currently noisy.
         assert self.weights_are_currently_noisy is True
 
-        # Removes the stored noise from the model's parameters.
         if self.framework == "tf":
-            if tf.executing_eagerly():
-                self._tf_remove_noise_op()
-            else:
-                tf_sess.run(self.tf_remove_noise_op)
+            tf_sess.run(self.tf_remove_noise_op)
+        elif self.framework == "tfe":
+            self._tf_remove_noise_op()
         else:
+            # Removes the stored noise from the model's parameters.
             for var, noise in zip(self.model_variables, self.noise):
                 # Remove noise from weights in-place.
                 var.add_(-noise)
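Across these three ParameterNoise hunks the branching converges on one dispatch pattern: static-graph "tf" runs a pre-built op through the session, eager "tfe" calls the helper directly, and the torch path mutates weights in place. A self-contained toy of that pattern (hypothetical class and method names, not RLlib code):

class NoiseDispatchToy:
    """Toy illustration of the tf / tfe / torch branching used above."""

    def __init__(self, framework):
        self.framework = framework
        self.log = []

    def _eager_helper(self):
        self.log.append("ran eagerly")

    def apply(self, tf_sess=None, graph_op=None):
        if self.framework == "tf":
            tf_sess.run(graph_op)               # graph mode: run pre-built op
        elif self.framework == "tfe":
            self._eager_helper()                # eager mode: plain Python call
        else:
            self.log.append("torch in-place")   # e.g. var.add_(noise)


NoiseDispatchToy("tfe").apply()  # runs the eager branch, no session needed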
@@ -28,7 +28,7 @@ class Random(Exploration):
 
         Args:
             action_space (Space): The gym action space used by the environment.
-            framework (Optional[str]): One of None, "tf", "torch".
+            framework (Optional[str]): One of None, "tf", "tfe", "torch".
         """
         super().__init__(
             action_space=action_space,
@@ -46,7 +46,7 @@ class Random(Exploration):
                                timestep: Union[int, TensorType],
                                explore: bool = True):
         # Instantiate the distribution object.
-        if self.framework == "tf":
+        if self.framework in ["tf", "tfe"]:
             return self.get_tf_exploration_action_op(action_distribution,
                                                      explore)
         else:
@@ -1,17 +1,17 @@
 import logging
 import os
 import sys
-from typing import Any, Union
+from typing import Any, Optional
+
+from ray.rllib.utils.types import TensorStructType, TensorShape, TensorType
 
 logger = logging.getLogger(__name__)
 
 # Represents a generic tensor type.
-# TODO(ekl) this is duplicated in types.py
-TensorType = Any
+TensorType = TensorType
 
 # Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
-# TODO(ekl) this is duplicated in types.py
-TensorStructType = Union[TensorType, dict, tuple]
+TensorStructType = TensorStructType
 
 
 def try_import_tf(error=False):
@@ -39,6 +39,9 @@ def try_import_tf(error=False):
     if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
+    # TODO: (sven) Allow env var to force compat.v1 behavior even if tf2.x
+    # installed.
+
     # Try to reuse already imported tf module. This will avoid going through
     # the initial import steps below and thereby switching off v2_behavior
     # (switching off v2 behavior twice breaks all-framework tests for eager).
@@ -160,15 +163,18 @@ def _torch_stubs():
 
 
 def get_variable(value,
-                 framework="tf",
-                 trainable=False,
-                 tf_name="unnamed-variable",
-                 torch_tensor=False,
-                 device=None):
+                 framework: str = "tf",
+                 trainable: bool = False,
+                 tf_name: str = "unnamed-variable",
+                 torch_tensor: bool = False,
+                 device: Optional[str] = None,
+                 shape: Optional[TensorShape] = None,
+                 dtype: Optional[Any] = None):
     """
     Args:
         value (any): The initial value to use. In the non-tf case, this will
-            be returned as is.
+            be returned as is. In the tf case, this could be a tf-Initializer
+            object.
         framework (str): One of "tf", "torch", or None.
         trainable (bool): Whether the generated variable should be
             trainable (tf)/require_grad (torch) or not (default: False).
@@ -176,19 +182,27 @@ def get_variable(value,
             tf.Variable.
         torch_tensor (bool): For framework="torch": Whether to actually create
             a torch.tensor, or just a python value (default).
         device (Optional[torch.Device]): An optional torch device to use for
             the created torch tensor.
+        shape (Optional[TensorShape]): An optional shape to use iff `value`
+            does not have any (e.g. if it's an initializer w/o explicit value).
+        dtype (Optional[TensorType]): An optional dtype to use iff `value` does
+            not have any (e.g. if it's an initializer w/o explicit value).
 
     Returns:
         any: A framework-specific variable (tf.Variable, torch.tensor, or
             python primitive).
     """
-    if framework == "tf":
+    if framework in ["tf", "tfe"]:
         import tensorflow as tf
-        dtype = getattr(
+        dtype = dtype or getattr(
             value, "dtype", tf.float32
             if isinstance(value, float) else tf.int32
             if isinstance(value, int) else None)
         return tf.compat.v1.get_variable(
-            tf_name, initializer=value, dtype=dtype, trainable=trainable)
+            tf_name, initializer=value, dtype=dtype, trainable=trainable,
+            **({} if shape is None else {"shape": shape})
+        )
     elif framework == "torch" and torch_tensor is True:
         torch, _ = try_import_torch()
         var_ = torch.from_numpy(value)
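A hedged usage sketch of the extended get_variable(): when value is an initializer rather than a concrete array, the new shape and dtype arguments supply what the initializer cannot. The call below is illustrative (names and shapes are made up) and assumes graph-mode (compat.v1) variable creation, which is how RLlib's try_import_tf() set TF up at the time.

import tensorflow as tf
from ray.rllib.utils.framework import get_variable

tf.compat.v1.disable_eager_execution()  # compat.v1 get_variable needs graph mode

# `value` is an initializer here, so shape/dtype must be given explicitly.
sigma_w = get_variable(
    tf.keras.initializers.GlorotUniform(),
    framework="tf",
    trainable=True,
    tf_name="sigma_w",
    shape=(16, 4),
    dtype=tf.float32)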
@@ -35,7 +35,7 @@ class Schedule(metaclass=ABCMeta):
 
         Returns:
             any: The calculated value depending on the schedule and `t`.
         """
-        if self.framework == "tf" and not tf.executing_eagerly():
+        if self.framework in ["tf", "tfe"]:
             return self._tf_value_op(t)
         return self._value(t)