ray/rllib/algorithms/dqn/distributional_q_tf_model.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

189 lines
7.8 KiB
Python
Raw Normal View History

"""Tensorflow model for DQN"""
from typing import List
import gym
from ray.rllib.models.tf.layers import NoisyLayer
2019-07-03 15:59:47 -07:00
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
2019-07-03 15:59:47 -07:00
tf1, tf, tfv = try_import_tf()
2019-07-03 15:59:47 -07:00
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
class DistributionalQTFModel(TFModelV2):
2019-07-03 15:59:47 -07:00
"""Extension of standard TFModel to provide distributional Q values.
It also supports options for noisy nets and parameter space noise.
Data flow:
obs -> forward() -> model_out
model_out -> get_q_value_distributions() -> Q(s, a) atoms
model_out -> get_state_value() -> V(s)
2019-07-03 15:59:47 -07:00
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
q_hiddens=(256,),
dueling: bool = False,
num_atoms: int = 1,
use_noisy: bool = False,
v_min: float = -10.0,
v_max: float = 10.0,
sigma0: float = 0.5,
# TODO(sven): Move `add_layer_norm` into ModelCatalog as
# generic option, then error if we use ParameterNoise as
# Exploration type and do not have any LayerNorm layers in
# the net.
add_layer_norm: bool = False,
):
2019-07-03 15:59:47 -07:00
"""Initialize variables of this model.
Extra model kwargs:
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
q_hiddens (List[int]): List of layer-sizes after(!) the
Advantages(A)/Value(V)-split. Hence, each of the A- and V-
branches will have this structure of Dense layers. To define
the NN before this A/V-split, use - as always -
config["model"]["fcnet_hiddens"].
dueling: Whether to build the advantage(A)/value(V) heads
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
for DDQN. If True, Q-values are calculated as:
Q = (A - mean[A]) + V. If False, raw NN output is interpreted
as Q-values.
num_atoms: If >1, enables distributional DQN.
use_noisy: Use noisy nets.
v_min: Min value support for distributional DQN.
v_max: Max value support for distributional DQN.
sigma0 (float): Initial value of noisy layers.
add_layer_norm: Enable layer norm (for param noise).
2019-07-03 15:59:47 -07:00
Note that the core layers for forward() are not defined here, this
only defines the layers for the Q head. Those layers for forward()
should be defined in subclasses of DistributionalQModel.
"""
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
super(DistributionalQTFModel, self).__init__(
2019-07-03 15:59:47 -07:00
obs_space, action_space, num_outputs, model_config, name
)
# setup the Q head output (i.e., model for get_q_values)
self.model_out = tf.keras.layers.Input(shape=(num_outputs,), name="model_out")
def build_action_value(prefix: str, model_out: TensorType) -> List[TensorType]:
2019-07-03 15:59:47 -07:00
if q_hiddens:
action_out = model_out
for i in range(len(q_hiddens)):
if use_noisy:
action_out = NoisyLayer(
"{}hidden_{}".format(prefix, i), q_hiddens[i], sigma0
)(action_out)
elif add_layer_norm:
action_out = tf.keras.layers.Dense(
units=q_hiddens[i], activation=tf.nn.relu
)(action_out)
action_out = tf.keras.layers.LayerNormalization()(action_out)
2019-07-03 15:59:47 -07:00
else:
action_out = tf.keras.layers.Dense(
2019-07-03 15:59:47 -07:00
units=q_hiddens[i],
activation=tf.nn.relu,
name="hidden_%d" % i,
)(action_out)
2019-07-03 15:59:47 -07:00
else:
# Avoid postprocessing the outputs. This enables custom models
# to be used for parametric action DQN.
action_out = model_out
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
2019-07-03 15:59:47 -07:00
if use_noisy:
action_scores = NoisyLayer(
"{}output".format(prefix),
2019-07-03 15:59:47 -07:00
self.action_space.n * num_atoms,
sigma0,
activation=None,
)(action_out)
2019-07-03 15:59:47 -07:00
elif q_hiddens:
action_scores = tf.keras.layers.Dense(
2019-07-03 15:59:47 -07:00
units=self.action_space.n * num_atoms, activation=None
)(action_out)
2019-07-03 15:59:47 -07:00
else:
action_scores = model_out
2019-07-03 15:59:47 -07:00
if num_atoms > 1:
# Distributional Q-learning uses a discrete support z
# to represent the action value distribution
z = tf.range(num_atoms, dtype=tf.float32)
z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
def _layer(x):
support_logits_per_action = tf.reshape(
tensor=x, shape=(-1, self.action_space.n, num_atoms)
)
support_prob_per_action = tf.nn.softmax(
logits=support_logits_per_action
)
x = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
logits = support_logits_per_action
dist = support_prob_per_action
return [x, z, support_logits_per_action, logits, dist]
return tf.keras.layers.Lambda(_layer)(action_scores)
2019-07-03 15:59:47 -07:00
else:
logits = tf.expand_dims(tf.ones_like(action_scores), -1)
dist = tf.expand_dims(tf.ones_like(action_scores), -1)
return [action_scores, logits, dist]
def build_state_score(prefix: str, model_out: TensorType) -> TensorType:
2019-07-03 15:59:47 -07:00
state_out = model_out
for i in range(len(q_hiddens)):
if use_noisy:
state_out = NoisyLayer(
"{}dueling_hidden_{}".format(prefix, i), q_hiddens[i], sigma0
)(state_out)
2019-07-03 15:59:47 -07:00
else:
state_out = tf.keras.layers.Dense(
units=q_hiddens[i], activation=tf.nn.relu
)(state_out)
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
if add_layer_norm:
state_out = tf.keras.layers.LayerNormalization()(state_out)
2019-07-03 15:59:47 -07:00
if use_noisy:
state_score = NoisyLayer(
"{}dueling_output".format(prefix),
2019-07-03 15:59:47 -07:00
num_atoms,
sigma0,
activation=None,
)(state_out)
2019-07-03 15:59:47 -07:00
else:
state_score = tf.keras.layers.Dense(units=num_atoms, activation=None)(
state_out
)
2019-07-03 15:59:47 -07:00
return state_score
q_out = build_action_value(name + "/action_value/", self.model_out)
2019-07-03 15:59:47 -07:00
self.q_value_head = tf.keras.Model(self.model_out, q_out)
if dueling:
state_out = build_state_score(name + "/state_value/", self.model_out)
2019-07-03 15:59:47 -07:00
self.state_value_head = tf.keras.Model(self.model_out, state_out)
def get_q_value_distributions(self, model_out: TensorType) -> List[TensorType]:
2019-07-03 15:59:47 -07:00
"""Returns distributional values for Q(s, a) given a state embedding.
Override this in your custom model to customize the Q output head.
2020-09-20 11:27:02 +02:00
Args:
model_out: embedding from the model layers
2019-07-03 15:59:47 -07:00
Returns:
(action_scores, logits, dist) if num_atoms == 1, otherwise
(action_scores, z, support_logits_per_action, logits, dist)
"""
return self.q_value_head(model_out)
def get_state_value(self, model_out: TensorType) -> TensorType:
2019-07-03 15:59:47 -07:00
"""Returns the state value prediction for the given state embedding."""
return self.state_value_head(model_out)