ray/rllib/agents/dqn/simple_q_torch_policy.py

101 lines
3.5 KiB
Python
Raw Normal View History

[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
"""Basic example of a DQN policy without any optimizations."""
import logging
import ray
from ray.rllib.agents.dqn.simple_q_tf_policy import build_q_models, \
get_distribution_inputs_and_class, compute_q_values
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
from ray.rllib.utils.torch_ops import huber_loss
torch, nn = try_import_torch()
F = None
if nn:
F = nn.functional
logger = logging.getLogger(__name__)
class TargetNetworkMixin:
def __init__(self, obs_space, action_space, config):
def do_update():
# Update_target_fn will be called periodically to copy Q network to
# target Q network.
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
(self.q_func_vars, self.target_q_func_vars)
self.target_q_model.load_state_dict(self.q_model.state_dict())
self.update_target = do_update
def build_q_model_and_distribution(policy, obs_space, action_space, config):
return build_q_models(policy, obs_space, action_space, config), \
TorchCategorical
def build_q_losses(policy, model, dist_class, train_batch):
# q network evaluation
q_t = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.CUR_OBS],
explore=False,
is_training=True)
# target q network evalution
q_tp1 = compute_q_values(
policy,
policy.target_q_model,
train_batch[SampleBatch.NEXT_OBS],
explore=False,
is_training=True)
# q scores for actions which we know were selected in the given state.
one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
[RLlib] DQN torch version. (#7597) * Fix. * Rollback. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * WIP. * Fix. * Fix. * Fix. * Fix. * Fix. * WIP. * WIP. * Fix. * Test case fixes. * Test case fixes and LINT. * Test case fixes and LINT. * Rollback. * WIP. * WIP. * Test case fixes. * Fix. * Fix. * Fix. * Add regression test for DQN w/ param noise. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Fixes and LINT. * Comment * Regression test case. * WIP. * WIP. * LINT. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * WIP. * LINT. * Fixes and LINT. * LINT and fixes. * LINT. * Move action_dist back into torch extra_action_out_fn and LINT. * Working SimpleQ learning cartpole on both torch AND tf. * Working Rainbow learning cartpole on tf. * Working Rainbow learning cartpole on tf. * WIP. * LINT. * LINT. * Update docs and add torch to APEX test. * LINT. * Fix. * LINT. * Fix. * Fix. * Fix and docstrings. * Fix broken RLlib tests in master. * Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier). * Fix error_outputs option in BAZEL for RLlib regression tests. * Fix. * Tune param-noise tests. * LINT. * Fix. * Fix. * test * test * test * Fix. * Fix. * WIP. * WIP. * WIP. * WIP. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 20:56:16 +02:00
policy.action_space.n)
q_t_selected = torch.sum(q_t * one_hot_selection, 1)
# compute estimate of best possible value starting from state at t + 1
dones = train_batch[SampleBatch.DONES].float()
q_tp1_best_one_hot_selection = F.one_hot(
torch.argmax(q_tp1, 1), policy.action_space.n)
q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
q_tp1_best_masked = (1.0 - dones) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
policy.config["gamma"] * q_tp1_best_masked)
# Compute the error (Square/Huber).
td_error = q_t_selected - q_t_selected_target.detach()
loss = torch.mean(huber_loss(td_error))
# save TD error as an attribute for outside access
policy.td_error = td_error
return loss
def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
"""Adds q-values to action out dict."""
return {"q_values": policy.q_values}
def setup_late_mixins(policy, obs_space, action_space, config):
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
SimpleQTorchPolicy = build_torch_policy(
name="SimpleQPolicy",
loss_fn=build_q_losses,
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
extra_action_out_fn=extra_action_out_fn,
after_init=setup_late_mixins,
make_model_and_action_dist=build_q_model_and_distribution,
mixins=[TargetNetworkMixin],
action_distribution_fn=get_distribution_inputs_and_class,
stats_fn=lambda policy, config: {"td_error": policy.td_error},
)