# ray/rllib/agents/dqn/dqn_tf_policy.py
"""TensorFlow policy class used for DQN"""
from typing import Dict
import gym
import numpy as np
import ray
from ray.rllib.agents.dqn.distributional_q_tf_model import \
DistributionalQTFModel
from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration import ParameterNoise
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.tf_ops import (huber_loss, make_tf_callable,
minimize_and_clip, reduce_mean_ignore_inf)
from ray.rllib.utils.typing import (ModelGradients, TensorType,
TrainerConfigDict)

tf1, tf, tfv = try_import_tf()

Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"

# Importance sampling weights for prioritized replay.
PRIO_WEIGHTS = "weights"


class QLoss:
def __init__(self,
q_t_selected: TensorType,
q_logits_t_selected: TensorType,
q_tp1_best: TensorType,
q_dist_tp1_best: TensorType,
importance_weights: TensorType,
rewards: TensorType,
done_mask: TensorType,
gamma: float = 0.99,
n_step: int = 1,
num_atoms: int = 1,
v_min: float = -10.0,
v_max: float = 10.0):
if num_atoms > 1:
            # Distributional Q-learning, which corresponds to a
            # cross-entropy loss.
z = tf.range(num_atoms, dtype=tf.float32)
z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
# (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
r_tau = tf.expand_dims(
rewards, -1) + gamma**n_step * tf.expand_dims(
1.0 - done_mask, -1) * tf.expand_dims(z, 0)
r_tau = tf.clip_by_value(r_tau, v_min, v_max)
b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
lb = tf.floor(b)
ub = tf.math.ceil(b)
            # Indispensable correction that most implementations miss: when
            # b happens to be an integer, lb == ub, so pr_j(s', a*) would be
            # discarded because (ub - b) == (b - lb) == 0.
floor_equal_ceil = tf.cast(tf.less(ub - lb, 0.5), tf.float32)
l_project = tf.one_hot(
tf.cast(lb, dtype=tf.int32),
num_atoms) # (batch_size, num_atoms, num_atoms)
u_project = tf.one_hot(
tf.cast(ub, dtype=tf.int32),
num_atoms) # (batch_size, num_atoms, num_atoms)
ml_delta = q_dist_tp1_best * (ub - b + floor_equal_ceil)
mu_delta = q_dist_tp1_best * (b - lb)
ml_delta = tf.reduce_sum(
l_project * tf.expand_dims(ml_delta, -1), axis=1)
mu_delta = tf.reduce_sum(
u_project * tf.expand_dims(mu_delta, -1), axis=1)
m = ml_delta + mu_delta
            # The Rainbow paper claims that using this cross-entropy loss as
            # the priority is robust and insensitive to
            # `prioritized_replay_alpha`.
self.td_error = tf.nn.softmax_cross_entropy_with_logits(
labels=m, logits=q_logits_t_selected)
self.loss = tf.reduce_mean(
self.td_error * tf.cast(importance_weights, tf.float32))
self.stats = {
# TODO: better Q stats for dist dqn
"mean_td_error": tf.reduce_mean(self.td_error),
}
else:
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
            # Compute the RHS of the Bellman equation.
q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked
            # Compute the TD error (potentially clipped via the Huber loss
            # below).
self.td_error = (
q_t_selected - tf.stop_gradient(q_t_selected_target))
self.loss = tf.reduce_mean(
tf.cast(importance_weights, tf.float32) * huber_loss(
self.td_error))
self.stats = {
"mean_q": tf.reduce_mean(q_t_selected),
"min_q": tf.reduce_min(q_t_selected),
"max_q": tf.reduce_max(q_t_selected),
"mean_td_error": tf.reduce_mean(self.td_error),
}
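

# --- Illustrative sketch (not part of the policy) ---------------------------
# A minimal NumPy re-derivation of the distributional (C51) projection that
# `QLoss` performs above, on hypothetical toy values. It mirrors the
# lb/ub/floor_equal_ceil handling so the tensor version is easier to follow.
def _categorical_projection_sketch():
    num_atoms, v_min, v_max, gamma, n_step = 5, -2.0, 2.0, 0.99, 1
    delta_z = (v_max - v_min) / (num_atoms - 1)
    z = v_min + np.arange(num_atoms) * delta_z
    # One sample: reward 0.5, non-terminal, uniform next-state distribution.
    reward, done = 0.5, 0.0
    p_next = np.full(num_atoms, 1.0 / num_atoms)
    r_tau = np.clip(reward + gamma**n_step * (1.0 - done) * z, v_min, v_max)
    b = (r_tau - v_min) / delta_z
    lb, ub = np.floor(b), np.ceil(b)
    # When b is an integer (lb == ub), route the full mass via the lb term.
    eq = (ub - lb < 0.5).astype(np.float64)
    m = np.zeros(num_atoms)
    for j in range(num_atoms):
        m[int(lb[j])] += p_next[j] * (ub[j] - b[j] + eq[j])
        m[int(ub[j])] += p_next[j] * (b[j] - lb[j])
    return m  # Projected target distribution; sums to 1.0.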


class ComputeTDErrorMixin:
    """Assigns the `compute_td_error` method to DQNTFPolicy.

    This allows prioritization on the worker side.
    """
def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
# Do forward pass on loss to update td error attribute
build_q_losses(
self, self.model, None, {
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
})
return self.q_loss.td_error
self.compute_td_error = compute_td_error
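

# Illustrative usage sketch (assumes a built DQNTFPolicy `policy` and NumPy
# arrays from a sampled batch; nothing here is executed in this module):
#
#     td_error = policy.compute_td_error(
#         batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
#         batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
#         batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
#
# `postprocess_nstep_and_prio()` below uses exactly this call to compute new
# priorities on the rollout workers.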


def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_model for DQN.

    Args:
        policy (Policy): The Policy, which will use the model for
            optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target q model will not be returned, just assigned to
            `policy.target_model`.
    """
if not isinstance(action_space, gym.spaces.Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(action_space))
if config["hiddens"]:
        # Try to infer the last layer size, otherwise fall back to 256.
num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
config["model"]["no_final_linear"] = True
else:
num_outputs = action_space.n
q_model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework="tf",
model_interface=DistributionalQTFModel,
name=Q_SCOPE,
num_atoms=config["num_atoms"],
dueling=config["dueling"],
q_hiddens=config["hiddens"],
use_noisy=config["noisy"],
v_min=config["v_min"],
v_max=config["v_max"],
sigma0=config["sigma0"],
# TODO(sven): Move option to add LayerNorm after each Dense
# generically into ModelCatalog.
add_layer_norm=isinstance(
getattr(policy, "exploration", None), ParameterNoise)
or config["exploration_config"]["type"] == "ParameterNoise")
policy.target_model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework="tf",
model_interface=DistributionalQTFModel,
name=Q_TARGET_SCOPE,
num_atoms=config["num_atoms"],
dueling=config["dueling"],
q_hiddens=config["hiddens"],
use_noisy=config["noisy"],
v_min=config["v_min"],
v_max=config["v_max"],
sigma0=config["sigma0"],
# TODO(sven): Move option to add LayerNorm after each Dense
# generically into ModelCatalog.
add_layer_norm=isinstance(
getattr(policy, "exploration", None), ParameterNoise)
or config["exploration_config"]["type"] == "ParameterNoise")
return q_model
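

# Illustrative config sketch: the keys read by `build_q_model()` above, with
# hypothetical example values (see `ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG`
# for the authoritative defaults):
#
#     config = {
#         "hiddens": [256],   # Q-head layers after the model; [] disables.
#         "num_atoms": 1,     # >1 switches to distributional Q-learning.
#         "dueling": True,    # Separate advantage and state-value streams.
#         "noisy": False,     # NoisyNet layers for exploration.
#         "v_min": -10.0,     # Lower bound of the support (num_atoms > 1).
#         "v_max": 10.0,      # Upper bound of the support (num_atoms > 1).
#         "sigma0": 0.5,      # Initial NoisyNet noise scale.
#     }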


def get_distribution_inputs_and_class(policy: Policy,
model: ModelV2,
obs_batch: TensorType,
*,
explore=True,
**kwargs):
q_vals = compute_q_values(
policy, model, {"obs": obs_batch}, state_batches=None, explore=explore)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
policy.q_values = q_vals
return policy.q_values, Categorical, [] # state-out


def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
config = policy.config
    # Q-network evaluation.
q_t, q_logits_t, q_dist_t, _ = compute_q_values(
policy,
model, {"obs": train_batch[SampleBatch.CUR_OBS]},
state_batches=None,
explore=False)
    # Target Q-network evaluation.
q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values(
policy,
policy.target_model, {"obs": train_batch[SampleBatch.NEXT_OBS]},
state_batches=None,
explore=False)
if not hasattr(policy, "target_q_func_vars"):
policy.target_q_func_vars = policy.target_model.variables()
    # Q scores for actions that we know were selected in the given state.
one_hot_selection = tf.one_hot(
tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32),
policy.action_space.n)
q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
q_logits_t_selected = tf.reduce_sum(
q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)
    # Compute estimate of the best possible value starting from the state
    # at t+1.
if config["double_q"]:
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net, _ = compute_q_values(
policy, model,
{"obs": train_batch[SampleBatch.NEXT_OBS]},
state_batches=None,
explore=False)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net,
policy.action_space.n)
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
q_dist_tp1_best = tf.reduce_sum(
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)
else:
q_tp1_best_one_hot_selection = tf.one_hot(
tf.argmax(q_tp1, 1), policy.action_space.n)
q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
q_dist_tp1_best = tf.reduce_sum(
q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)
policy.q_loss = QLoss(
q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
tf.cast(train_batch[SampleBatch.DONES],
tf.float32), config["gamma"], config["n_step"],
config["num_atoms"], config["v_min"], config["v_max"])
return policy.q_loss.loss
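

# --- Illustrative sketch (not part of the policy) ---------------------------
# Minimal NumPy version of the double-Q branch above, on hypothetical toy
# Q-tables: the online net selects the argmax actions, the target net then
# evaluates them, which decouples action selection from value estimation.
def _double_q_target_sketch():
    q_tp1_online = np.array([[1.0, 2.0], [3.0, 0.5]])  # Online net at s_tp1.
    q_tp1_target = np.array([[0.8, 1.5], [2.5, 0.9]])  # Target net at s_tp1.
    best_actions = q_tp1_online.argmax(axis=1)  # a* = argmax_a Q_online.
    q_tp1_best = q_tp1_target[np.arange(len(best_actions)), best_actions]
    return q_tp1_best  # -> array([1.5, 2.5]): Q_target(s', a*).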


def adam_optimizer(policy: Policy, config: TrainerConfigDict
                   ) -> "tf.keras.optimizers.Optimizer":
if policy.config["framework"] in ["tf2", "tfe"]:
return tf.keras.optimizers.Adam(
learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])
else:
return tf1.train.AdamOptimizer(
learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])


def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
                   loss: TensorType) -> ModelGradients:
if not hasattr(policy, "q_func_vars"):
policy.q_func_vars = policy.model.variables()
return minimize_and_clip(
optimizer,
loss,
var_list=policy.q_func_vars,
clip_val=policy.config["grad_clip"])


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
return dict({
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
}, **policy.q_loss.stats)


def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
ComputeTDErrorMixin.__init__(policy)


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


def compute_q_values(policy: Policy,
model: ModelV2,
input_dict,
state_batches=None,
seq_lens=None,
explore=None,
is_training: bool = False):
config = policy.config
input_dict["is_training"] = policy._get_is_training_placeholder()
model_out, state = model(input_dict, state_batches or [], seq_lens)
if config["num_atoms"] > 1:
(action_scores, z, support_logits_per_action, logits,
dist) = model.get_q_value_distributions(model_out)
else:
(action_scores, logits,
dist) = model.get_q_value_distributions(model_out)
if config["dueling"]:
state_score = model.get_state_value(model_out)
if config["num_atoms"] > 1:
support_logits_per_action_mean = tf.reduce_mean(
support_logits_per_action, 1)
support_logits_per_action_centered = (
support_logits_per_action - tf.expand_dims(
support_logits_per_action_mean, 1))
support_logits_per_action = tf.expand_dims(
state_score, 1) + support_logits_per_action_centered
support_prob_per_action = tf.nn.softmax(
logits=support_logits_per_action)
value = tf.reduce_sum(
input_tensor=z * support_prob_per_action, axis=-1)
logits = support_logits_per_action
dist = support_prob_per_action
else:
action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(
action_scores_mean, 1)
value = state_score + action_scores_centered
else:
value = action_scores
return value, logits, dist, state
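

# --- Illustrative sketch (not part of the policy) ---------------------------
# NumPy version of the non-distributional dueling aggregation above:
# Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), on hypothetical toy values.
# The real code uses `reduce_mean_ignore_inf()` so that masked (-inf) actions
# do not distort the mean.
def _dueling_aggregation_sketch():
    state_score = np.array([[2.0]])               # V(s), shape (batch, 1).
    action_scores = np.array([[1.0, 3.0, -1.0]])  # A(s, a), (batch, actions).
    centered = action_scores - action_scores.mean(axis=1, keepdims=True)
    return state_score + centered  # -> array([[2., 4., 0.]])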


def _adjust_nstep(n_step: int, gamma: float, obs: TensorType,
                  actions: TensorType, rewards: TensorType,
                  new_obs: TensorType, dones: TensorType):
    """Rewrites the given trajectory fragments to encode n-step rewards.

    reward[i] = (
        reward[i] * gamma**0 +
        reward[i+1] * gamma**1 +
        ... +
        reward[i+n_step-1] * gamma**(n_step-1))

    The i-th new_obs is also adjusted to point to the (i+n_step-1)-th new
    obs. At the end of the trajectory, n is truncated to fit in the
    trajectory length.
    """
assert not any(dones[:-1]), "Unexpected done in middle of trajectory"
traj_length = len(rewards)
for i in range(traj_length):
for j in range(1, n_step):
if i + j < traj_length:
new_obs[i] = new_obs[i + j]
dones[i] = dones[i + j]
rewards[i] += gamma**j * rewards[i + j]
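

# --- Illustrative sketch (not part of the policy) ---------------------------
# Worked example of `_adjust_nstep()` on hypothetical toy values: n_step=2,
# gamma=0.9, and a 3-step trajectory fragment that ends in a terminal state.
def _adjust_nstep_sketch():
    obs = [0, 1, 2]
    actions = [0, 1, 0]
    rewards = [1.0, 1.0, 1.0]
    new_obs = [1, 2, 3]
    dones = [False, False, True]
    _adjust_nstep(2, 0.9, obs, actions, rewards, new_obs, dones)
    # rewards -> [1.9, 1.9, 1.0]  (r_i + 0.9 * r_{i+1}; truncated at the end).
    # new_obs -> [2, 3, 3]        (each entry now points n_step - 1 ahead).
    # dones   -> [False, True, True]
    return rewards, new_obs, dones

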
def postprocess_nstep_and_prio(policy: Policy,
batch: SampleBatch,
other_agent=None,
episode=None) -> SampleBatch:
# N-step Q adjustments.
if policy.config["n_step"] > 1:
_adjust_nstep(policy.config["n_step"], policy.config["gamma"],
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES])
if PRIO_WEIGHTS not in batch:
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
# Prioritize on the worker side.
if batch.count > 0 and policy.config["worker_side_prioritization"]:
td_errors = policy.compute_td_error(
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
new_priorities = (np.abs(convert_to_numpy(td_errors)) +
policy.config["prioritized_replay_eps"])
batch[PRIO_WEIGHTS] = new_priorities
return batch


DQNTFPolicy = build_tf_policy(
name="DQNTFPolicy",
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
make_model=build_q_model,
action_distribution_fn=get_distribution_inputs_and_class,
loss_fn=build_q_losses,
stats_fn=build_q_stats,
postprocess_fn=postprocess_nstep_and_prio,
optimizer_fn=adam_optimizer,
compute_gradients_fn=clip_gradients,
extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
before_loss_init=setup_mid_mixins,
after_init=setup_late_mixins,
mixins=[
TargetNetworkMixin,
ComputeTDErrorMixin,
LearningRateSchedule,
])
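

# Illustrative usage sketch (assumes Ray is installed and initialized; the
# DQN Trainer picks this policy automatically when framework="tf"):
#
#     from ray.rllib.agents.dqn import DQNTrainer
#     trainer = DQNTrainer(env="CartPole-v0", config={"framework": "tf"})
#     print(trainer.train()["episode_reward_mean"])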