mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
100 lines
3.5 KiB
Python
100 lines
3.5 KiB
Python
"""Basic example of a DQN policy without any optimizations."""
|
|
|
|
import logging
|
|
|
|
import ray
|
|
from ray.rllib.agents.dqn.simple_q_tf_policy import build_q_models, \
|
|
get_distribution_inputs_and_class, compute_q_values
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
|
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
from ray.rllib.utils.torch_ops import huber_loss
|
|
|
|
torch, nn = try_import_torch()
|
|
F = None
|
|
if nn:
|
|
F = nn.functional
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TargetNetworkMixin:
|
|
def __init__(self, obs_space, action_space, config):
|
|
def do_update():
|
|
# Update_target_fn will be called periodically to copy Q network to
|
|
# target Q network.
|
|
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
|
|
(self.q_func_vars, self.target_q_func_vars)
|
|
self.target_q_model.load_state_dict(self.q_model.state_dict())
|
|
|
|
self.update_target = do_update
|
|
|
|
|
|
def build_q_model_and_distribution(policy, obs_space, action_space, config):
|
|
return build_q_models(policy, obs_space, action_space, config), \
|
|
TorchCategorical
|
|
|
|
|
|
def build_q_losses(policy, model, dist_class, train_batch):
|
|
# q network evaluation
|
|
q_t = compute_q_values(
|
|
policy,
|
|
policy.q_model,
|
|
train_batch[SampleBatch.CUR_OBS],
|
|
explore=False,
|
|
is_training=True)
|
|
|
|
# target q network evalution
|
|
q_tp1 = compute_q_values(
|
|
policy,
|
|
policy.target_q_model,
|
|
train_batch[SampleBatch.NEXT_OBS],
|
|
explore=False,
|
|
is_training=True)
|
|
|
|
# q scores for actions which we know were selected in the given state.
|
|
one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
|
|
policy.action_space.n)
|
|
q_t_selected = torch.sum(q_t * one_hot_selection, 1)
|
|
|
|
# compute estimate of best possible value starting from state at t + 1
|
|
dones = train_batch[SampleBatch.DONES].float()
|
|
q_tp1_best_one_hot_selection = F.one_hot(
|
|
torch.argmax(q_tp1, 1), policy.action_space.n)
|
|
q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
|
|
q_tp1_best_masked = (1.0 - dones) * q_tp1_best
|
|
|
|
# compute RHS of bellman equation
|
|
q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
|
|
policy.config["gamma"] * q_tp1_best_masked)
|
|
|
|
# Compute the error (Square/Huber).
|
|
td_error = q_t_selected - q_t_selected_target.detach()
|
|
loss = torch.mean(huber_loss(td_error))
|
|
|
|
# save TD error as an attribute for outside access
|
|
policy.td_error = td_error
|
|
|
|
return loss
|
|
|
|
|
|
def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
|
|
"""Adds q-values to action out dict."""
|
|
return {"q_values": policy.q_values}
|
|
|
|
|
|
def setup_late_mixins(policy, obs_space, action_space, config):
|
|
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
|
|
|
|
|
SimpleQTorchPolicy = build_torch_policy(
|
|
name="SimpleQPolicy",
|
|
loss_fn=build_q_losses,
|
|
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
|
|
extra_action_out_fn=extra_action_out_fn,
|
|
after_init=setup_late_mixins,
|
|
make_model_and_action_dist=build_q_model_and_distribution,
|
|
mixins=[TargetNetworkMixin],
|
|
action_distribution_fn=get_distribution_inputs_and_class,
|
|
stats_fn=lambda policy, config: {"td_error": policy.td_error},
|
|
)
|