From 01125b8fcfce88472c2e20d116bc6662032088e0 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 9 Jul 2020 10:44:10 +0200 Subject: [PATCH] [RLlib] DQN rainbow eager-mode (keras style NoisyLayer) (preparation for native tf2.x support). (#9304) --- rllib/BUILD | 2 +- rllib/agents/dqn/distributional_q_tf_model.py | 138 +++--------------- rllib/models/modelv2.py | 4 +- rllib/models/tf/layers/gru_gate.py | 2 +- .../models/tf/layers/multi_head_attention.py | 2 +- rllib/models/tf/layers/noisy_layer.py | 2 +- .../layers/relative_multi_head_attention.py | 2 +- rllib/models/tf/layers/skip_connection.py | 2 +- rllib/policy/eager_tf_policy.py | 20 ++- rllib/utils/exploration/epsilon_greedy.py | 4 +- rllib/utils/exploration/parameter_noise.py | 27 ++-- rllib/utils/exploration/random.py | 4 +- rllib/utils/framework.py | 42 ++++-- rllib/utils/schedules/schedule.py | 2 +- 14 files changed, 87 insertions(+), 166 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 1585434d1..fae2838fb 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -425,7 +425,7 @@ py_test( py_test( name = "test_dqn", tags = ["agents_dir"], - size = "medium", + size = "large", srcs = ["agents/dqn/tests/test_dqn.py"] ) py_test( diff --git a/rllib/agents/dqn/distributional_q_tf_model.py b/rllib/agents/dqn/distributional_q_tf_model.py index cc30e2d9d..b04ab191f 100644 --- a/rllib/agents/dqn/distributional_q_tf_model.py +++ b/rllib/agents/dqn/distributional_q_tf_model.py @@ -1,5 +1,4 @@ -import numpy as np - +from ray.rllib.models.tf.layers import NoisyLayer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.framework import try_import_tf @@ -69,13 +68,15 @@ class DistributionalQTFModel(TFModelV2): self.model_out = tf.keras.layers.Input( shape=(num_outputs, ), name="model_out") - def build_action_value(model_out): + def build_action_value(prefix, model_out): if q_hiddens: action_out = model_out for i in range(len(q_hiddens)): if use_noisy: - action_out = self._noisy_layer( - "hidden_%d" % i, action_out, q_hiddens[i], sigma0) + action_out = NoisyLayer( + "{}hidden_{}".format(prefix, i), + q_hiddens[i], + sigma0)(action_out) elif add_layer_norm: action_out = tf.keras.layers.Dense( units=q_hiddens[i], @@ -94,12 +95,11 @@ class DistributionalQTFModel(TFModelV2): action_out = model_out if use_noisy: - action_scores = self._noisy_layer( - "output", - action_out, + action_scores = NoisyLayer( + "{}output".format(prefix), self.action_space.n * num_atoms, sigma0, - non_linear=False) + activation=None)(action_out) elif q_hiddens: action_scores = tf.keras.layers.Dense( units=self.action_space.n * num_atoms, @@ -130,13 +130,14 @@ class DistributionalQTFModel(TFModelV2): dist = tf.expand_dims(tf.ones_like(action_scores), -1) return [action_scores, logits, dist] - def build_state_score(model_out): + def build_state_score(prefix, model_out): state_out = model_out for i in range(len(q_hiddens)): if use_noisy: - state_out = self._noisy_layer("dueling_hidden_%d" % i, - state_out, q_hiddens[i], - sigma0) + state_out = NoisyLayer( + "{}dueling_hidden_{}".format(prefix, i), + q_hiddens[i], + sigma0)(state_out) else: state_out = tf.keras.layers.Dense( units=q_hiddens[i], activation=tf.nn.relu)(state_out) @@ -144,59 +145,23 @@ class DistributionalQTFModel(TFModelV2): state_out = tf.keras.layers.LayerNormalization()( state_out) if use_noisy: - state_score = self._noisy_layer( - "dueling_output", - state_out, + state_score = NoisyLayer( + "{}dueling_output".format(prefix), num_atoms, sigma0, - non_linear=False) + activation=None)(state_out) else: state_score = tf.keras.layers.Dense( units=num_atoms, activation=None)(state_out) return state_score - if tf1.executing_eagerly(): - from tensorflow.python.ops import variable_scope - # Have to use a variable store to reuse variables in eager mode - store = variable_scope.EagerVariableStore() - - # Save the scope objects, since in eager we will execute this - # path repeatedly and there is no guarantee it will always be run - # in the same original scope. - with tf1.variable_scope(name + "/action_value") as action_scope: - pass - with tf1.variable_scope(name + "/state_value") as state_scope: - pass - - def build_action_value_in_scope(model_out): - with store.as_default(): - with tf1.variable_scope( - action_scope, reuse=tf1.AUTO_REUSE): - return build_action_value(model_out) - - def build_state_score_in_scope(model_out): - with store.as_default(): - with tf1.variable_scope( - state_scope, reuse=tf1.AUTO_REUSE): - return build_state_score(model_out) - else: - - def build_action_value_in_scope(model_out): - with tf1.variable_scope( - name + "/action_value", reuse=tf1.AUTO_REUSE): - return build_action_value(model_out) - - def build_state_score_in_scope(model_out): - with tf1.variable_scope( - name + "/state_value", reuse=tf1.AUTO_REUSE): - return build_state_score(model_out) - - q_out = build_action_value_in_scope(self.model_out) + q_out = build_action_value(name + "/action_value/", self.model_out) self.q_value_head = tf.keras.Model(self.model_out, q_out) self.register_variables(self.q_value_head.variables) if dueling: - state_out = build_state_score_in_scope(self.model_out) + state_out = build_state_score( + name + "/state_value/", self.model_out) self.state_value_head = tf.keras.Model(self.model_out, state_out) self.register_variables(self.state_value_head.variables) @@ -219,66 +184,3 @@ class DistributionalQTFModel(TFModelV2): """Returns the state value prediction for the given state embedding.""" return self.state_value_head(model_out) - - def _noisy_layer(self, - prefix, - action_in, - out_size, - sigma0, - non_linear=True): - """ - a common dense layer: y = w^{T}x + b - a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x + - (b+\\epsilon_b*\\sigma_b) - where \epsilon are random variables sampled from factorized normal - distributions and \\sigma are trainable variables which are expected to - vanish along the training procedure - """ - in_size = int(action_in.shape[1]) - - epsilon_in = tf.random.normal(shape=[in_size]) - epsilon_out = tf.random.normal(shape=[out_size]) - epsilon_in = self._f_epsilon(epsilon_in) - epsilon_out = self._f_epsilon(epsilon_out) - epsilon_w = tf.matmul( - a=tf.expand_dims(epsilon_in, -1), b=tf.expand_dims(epsilon_out, 0)) - epsilon_b = epsilon_out - sigma_w = tf1.get_variable( - name=prefix + "_sigma_w", - shape=[in_size, out_size], - dtype=tf.float32, - initializer=tf1.random_uniform_initializer( - minval=-1.0 / np.sqrt(float(in_size)), - maxval=1.0 / np.sqrt(float(in_size)))) - # TF noise generation can be unreliable on GPU - # If generating the noise on the CPU, - # lowering sigma0 to 0.1 may be helpful - sigma_b = tf1.get_variable( - name=prefix + "_sigma_b", - shape=[out_size], - dtype=tf.float32, # 0.5~GPU, 0.1~CPU - initializer=tf1.constant_initializer( - sigma0 / np.sqrt(float(in_size)))) - - w = tf1.get_variable( - name=prefix + "_fc_w", - shape=[in_size, out_size], - dtype=tf.float32, - initializer=tf.initializers.GlorotUniform()) - b = tf1.get_variable( - name=prefix + "_fc_b", - shape=[out_size], - dtype=tf.float32, - initializer=tf.initializers.Zeros()) - - action_activation = \ - tf.keras.layers.Lambda(lambda x: tf.matmul( - x, w + sigma_w * epsilon_w) + b + sigma_b * epsilon_b)( - action_in) - - if not non_linear: - return action_activation - return tf.nn.relu(action_activation) - - def _f_epsilon(self, x): - return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x)) diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 7247d119f..7a92c9eb9 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -338,8 +338,8 @@ class NullContextManager: @DeveloperAPI def flatten(obs, framework): """Flatten the given tensor.""" - if framework == "tf": - return tf1.layers.flatten(obs) + if framework in ["tf", "tfe"]: + return tf1.keras.layers.Flatten()(obs) elif framework == "torch": assert torch is not None return torch.flatten(obs, start_dim=1) diff --git a/rllib/models/tf/layers/gru_gate.py b/rllib/models/tf/layers/gru_gate.py index 69dba748c..89dbf652b 100644 --- a/rllib/models/tf/layers/gru_gate.py +++ b/rllib/models/tf/layers/gru_gate.py @@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf tf1, tf, tfv = try_import_tf() -class GRUGate(tf.keras.layers.Layer): +class GRUGate(tf.keras.layers.Layer if tf else object): def __init__(self, init_bias=0., **kwargs): super().__init__(**kwargs) self._init_bias = init_bias diff --git a/rllib/models/tf/layers/multi_head_attention.py b/rllib/models/tf/layers/multi_head_attention.py index 04583adaa..0971f186f 100644 --- a/rllib/models/tf/layers/multi_head_attention.py +++ b/rllib/models/tf/layers/multi_head_attention.py @@ -8,7 +8,7 @@ from ray.rllib.utils.framework import try_import_tf tf1, tf, tfv = try_import_tf() -class MultiHeadAttention(tf.keras.layers.Layer): +class MultiHeadAttention(tf.keras.layers.Layer if tf else object): """A multi-head attention layer described in [1].""" def __init__(self, out_dim, num_heads, head_dim, **kwargs): diff --git a/rllib/models/tf/layers/noisy_layer.py b/rllib/models/tf/layers/noisy_layer.py index a204bd222..9fa570db5 100644 --- a/rllib/models/tf/layers/noisy_layer.py +++ b/rllib/models/tf/layers/noisy_layer.py @@ -6,7 +6,7 @@ from ray.rllib.utils.framework import get_activation_fn, get_variable, \ tf1, tf, tfv = try_import_tf() -class NoisyLayer(tf.keras.layers.Layer): +class NoisyLayer(tf.keras.layers.Layer if tf else object): """A Layer that adds learnable Noise a common dense layer: y = w^{T}x + b a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x + diff --git a/rllib/models/tf/layers/relative_multi_head_attention.py b/rllib/models/tf/layers/relative_multi_head_attention.py index affd48cee..bd52c0bf7 100644 --- a/rllib/models/tf/layers/relative_multi_head_attention.py +++ b/rllib/models/tf/layers/relative_multi_head_attention.py @@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf tf1, tf, tfv = try_import_tf() -class RelativeMultiHeadAttention(tf.keras.layers.Layer): +class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): """A RelativeMultiHeadAttention layer as described in [3]. Uses segment level recurrence with state reuse. diff --git a/rllib/models/tf/layers/skip_connection.py b/rllib/models/tf/layers/skip_connection.py index f2f0e1d5f..9d6b766e4 100644 --- a/rllib/models/tf/layers/skip_connection.py +++ b/rllib/models/tf/layers/skip_connection.py @@ -3,7 +3,7 @@ from ray.rllib.utils.framework import try_import_tf tf1, tf, tfv = try_import_tf() -class SkipConnection(tf.keras.layers.Layer): +class SkipConnection(tf.keras.layers.Layer if tf else object): """Skip connection layer. Adds the original input to the output (regular residual layer) OR uses diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 13d9daf8f..3ef4a1aa0 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -200,7 +200,7 @@ def build_eager_tf_policy(name, class eager_policy_cls(base): def __init__(self, observation_space, action_space, config): assert tf.executing_eagerly() - self.framework = "tf" + self.framework = "tfe" Policy.__init__(self, observation_space, action_space, config) self._is_training = False self._loss_initialized = False @@ -235,7 +235,7 @@ def build_eager_tf_policy(name, action_space, logit_dim, config["model"], - framework="tf", + framework=self.framework, ) self.exploration = self._create_exploration() self._state_in = [ @@ -352,7 +352,8 @@ def build_eager_tf_policy(name, self.model, input_dict[SampleBatch.CUR_OBS], explore=explore, - timestep=timestep) + timestep=timestep, + episodes=episodes) else: # Exploration hook before each forward pass. self.exploration.before_compute_actions( @@ -457,8 +458,10 @@ def build_eager_tf_policy(name, return _convert_to_numpy(self.exploration.get_info()) @override(Policy) - def get_weights(self): + def get_weights(self, as_dict=False): variables = self.variables() + if as_dict: + return {v.name: v.numpy() for v in variables} return [v.numpy() for v in variables] @override(Policy) @@ -638,8 +641,8 @@ def build_eager_tf_policy(name, dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) # Convert everything to tensors. - dummy_batch = tf.nest.map_structure(tf1.convert_to_tensor, - dummy_batch) + dummy_batch = tf.nest.map_structure( + tf1.convert_to_tensor, dummy_batch) # for IMPALA which expects a certain sample batch size. def tile_to(tensor, n): @@ -650,6 +653,11 @@ def build_eager_tf_policy(name, dummy_batch = tf.nest.map_structure( lambda c: tile_to(c, get_batch_divisibility_req(self)), dummy_batch) + i = 0 + self._state_in = [] + while "state_in_{}".format(i) in dummy_batch: + self._state_in.append(dummy_batch["state_in_{}".format(i)]) + i += 1 # Execute a forward pass to get self.action_dist etc initialized, # and also obtain the extra action fetches diff --git a/rllib/utils/exploration/epsilon_greedy.py b/rllib/utils/exploration/epsilon_greedy.py index 75b17215e..77c13ea67 100644 --- a/rllib/utils/exploration/epsilon_greedy.py +++ b/rllib/utils/exploration/epsilon_greedy.py @@ -57,7 +57,7 @@ class EpsilonGreedy(Exploration): 0, framework=framework, tf_name="timestep") # Build the tf-info-op. - if self.framework == "tf": + if self.framework in ["tf", "tfe"]: self._tf_info_op = self.get_info() @override(Exploration) @@ -68,7 +68,7 @@ class EpsilonGreedy(Exploration): explore: bool = True): q_values = action_distribution.inputs - if self.framework == "tf": + if self.framework in ["tf", "tfe"]: return self._get_tf_exploration_action_op(q_values, explore, timestep) else: diff --git a/rllib/utils/exploration/parameter_noise.py b/rllib/utils/exploration/parameter_noise.py index abf59f188..dcd1c564b 100644 --- a/rllib/utils/exploration/parameter_noise.py +++ b/rllib/utils/exploration/parameter_noise.py @@ -290,10 +290,9 @@ class ParameterNoise(Exploration): def _sample_new_noise(self, *, tf_sess=None): """Samples new noise and stores it in `self.noise`.""" if self.framework == "tf": - if tf.executing_eagerly(): - self._tf_sample_new_noise_op() - else: - tf_sess.run(self.tf_sample_new_noise_op) + tf_sess.run(self.tf_sample_new_noise_op) + elif self.framework == "tfe": + self._tf_sample_new_noise_op() else: for i in range(len(self.noise)): self.noise[i] = torch.normal( @@ -312,7 +311,7 @@ class ParameterNoise(Exploration): return tf.group(*added_noises) def _sample_new_noise_and_add(self, *, tf_sess=None, override=False): - if self.framework == "tf" and not tf.executing_eagerly(): + if self.framework == "tf": if override and self.weights_are_currently_noisy: tf_sess.run(self.tf_remove_noise_op) tf_sess.run(self.tf_sample_new_noise_and_add_op) @@ -338,12 +337,11 @@ class ParameterNoise(Exploration): # Make sure we only add noise to currently noise-free weights. assert self.weights_are_currently_noisy is False - if self.framework == "tf": - if tf.executing_eagerly(): - self._tf_add_stored_noise_op() - else: - tf_sess.run(self.tf_add_stored_noise_op) # Add stored noise to the model's parameters. + if self.framework == "tf": + tf_sess.run(self.tf_add_stored_noise_op) + elif self.framework == "tfe": + self._tf_add_stored_noise_op() else: for i in range(len(self.noise)): # Add noise to weights in-place. @@ -377,13 +375,12 @@ class ParameterNoise(Exploration): # Make sure we only remove noise iff currently noisy. assert self.weights_are_currently_noisy is True + # Removes the stored noise from the model's parameters. if self.framework == "tf": - if tf.executing_eagerly(): - self._tf_remove_noise_op() - else: - tf_sess.run(self.tf_remove_noise_op) + tf_sess.run(self.tf_remove_noise_op) + elif self.framework == "tfe": + self._tf_remove_noise_op() else: - # Removes the stored noise from the model's parameters. for var, noise in zip(self.model_variables, self.noise): # Remove noise from weights in-place. var.add_(-noise) diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index 935848bbf..7664a46b2 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -28,7 +28,7 @@ class Random(Exploration): Args: action_space (Space): The gym action space used by the environment. - framework (Optional[str]): One of None, "tf", "torch". + framework (Optional[str]): One of None, "tf", "tfe", "torch". """ super().__init__( action_space=action_space, @@ -46,7 +46,7 @@ class Random(Exploration): timestep: Union[int, TensorType], explore: bool = True): # Instantiate the distribution object. - if self.framework == "tf": + if self.framework in ["tf", "tfe"]: return self.get_tf_exploration_action_op(action_distribution, explore) else: diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index ac032d2f0..6a140c1a7 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -1,17 +1,17 @@ import logging import os import sys -from typing import Any, Union +from typing import Any, Optional + +from ray.rllib.utils.types import TensorStructType, TensorShape, TensorType logger = logging.getLogger(__name__) # Represents a generic tensor type. -# TODO(ekl) this is duplicated in types.py -TensorType = Any +TensorType = TensorType # Either a plain tensor, or a dict or tuple of tensors (or StructTensors). -# TODO(ekl) this is duplicated in types.py -TensorStructType = Union[TensorType, dict, tuple] +TensorStructType = TensorStructType def try_import_tf(error=False): @@ -39,6 +39,9 @@ def try_import_tf(error=False): if "TF_CPP_MIN_LOG_LEVEL" not in os.environ: os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + # TODO: (sven) Allow env var to force compat.v1 behavior even if tf2.x + # installed. + # Try to reuse already imported tf module. This will avoid going through # the initial import steps below and thereby switching off v2_behavior # (switching off v2 behavior twice breaks all-framework tests for eager). @@ -160,15 +163,18 @@ def _torch_stubs(): def get_variable(value, - framework="tf", - trainable=False, - tf_name="unnamed-variable", - torch_tensor=False, - device=None): + framework: str = "tf", + trainable: bool = False, + tf_name: str = "unnamed-variable", + torch_tensor: bool = False, + device: Optional[str] = None, + shape: Optional[TensorShape] = None, + dtype: Optional[Any] = None): """ Args: value (any): The initial value to use. In the non-tf case, this will - be returned as is. + be returned as is. In the tf case, this could be a tf-Initializer + object. framework (str): One of "tf", "torch", or None. trainable (bool): Whether the generated variable should be trainable (tf)/require_grad (torch) or not (default: False). @@ -176,19 +182,27 @@ def get_variable(value, tf.Variable. torch_tensor (bool): For framework="torch": Whether to actually create a torch.tensor, or just a python value (default). + device (Optional[torch.Device]): An optional torch device to use for + the created torch tensor. + shape (Optional[TensorShape]): An optional shape to use iff `value` + does not have any (e.g. if it's an initializer w/o explicit value). + dtype (Optional[TensorType]): An optional dtype to use iff `value` does + not have any (e.g. if it's an initializer w/o explicit value). Returns: any: A framework-specific variable (tf.Variable, torch.tensor, or python primitive). """ - if framework == "tf": + if framework in ["tf", "tfe"]: import tensorflow as tf - dtype = getattr( + dtype = dtype or getattr( value, "dtype", tf.float32 if isinstance(value, float) else tf.int32 if isinstance(value, int) else None) return tf.compat.v1.get_variable( - tf_name, initializer=value, dtype=dtype, trainable=trainable) + tf_name, initializer=value, dtype=dtype, trainable=trainable, + **({} if shape is None else {"shape": shape}) + ) elif framework == "torch" and torch_tensor is True: torch, _ = try_import_torch() var_ = torch.from_numpy(value) diff --git a/rllib/utils/schedules/schedule.py b/rllib/utils/schedules/schedule.py index 316f359fe..ca21bd595 100644 --- a/rllib/utils/schedules/schedule.py +++ b/rllib/utils/schedules/schedule.py @@ -35,7 +35,7 @@ class Schedule(metaclass=ABCMeta): Returns: any: The calculated value depending on the schedule and `t`. """ - if self.framework == "tf" and not tf.executing_eagerly(): + if self.framework in ["tf", "tfe"]: return self._tf_value_op(t) return self._value(t)