"""PyTorch policy class used for Simple Q-Learning"""
|
|
|
|
import logging
|
|
from typing import Dict, Tuple
|
|
|
|
import gym
|
|
import ray
|
|
from ray.rllib.agents.dqn.simple_q_tf_policy import (
|
|
build_q_models, compute_q_values, get_distribution_inputs_and_class)
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
|
TorchDistributionWrapper
|
|
from ray.rllib.policy import Policy
|
|
from ray.rllib.policy.policy_template import build_policy_class
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.policy.torch_policy import TorchPolicy
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
from ray.rllib.utils.torch_utils import concat_multi_gpu_td_errors, huber_loss
|
|
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
|
|
|
torch, nn = try_import_torch()
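# `torch` may not be installed in every setup, so `nn.functional` is only
# resolved if the import above actually returned a torch module.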
F = None
if nn:
    F = nn.functional

logger = logging.getLogger(__name__)


class TargetNetworkMixin:
    """Assigns the `update_target` method to the SimpleQTorchPolicy.

    The method is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self):
        # Hard initial update from the Q-net(s) to the target Q-net(s).
        self.update_target()

    def update_target(self):
        # Called periodically to copy the Q-network's weights over to the
        # target Q-network(s).
        state_dict = self.model.state_dict()
        for target in self.target_models.values():
            target.load_state_dict(state_dict)

    @override(TorchPolicy)
    def set_weights(self, weights):
        # Make sure that whenever we restore weights for this policy's
        # model, we sync the target network (from the main model)
        # at the same time.
        TorchPolicy.set_weights(self, weights)
        self.update_target()


def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    return build_q_models(policy, obs_space, action_space, config), \
        TorchCategorical


def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]

    # Q-network evaluation.
    q_t = compute_q_values(
        policy,
        model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True)

    # Target Q-network evaluation.
    q_tp1 = compute_q_values(
        policy,
        target_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True)

    # Q-scores for the actions that we know were selected in the given state.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # Compute an estimate of the best possible value starting from the state
    # at t + 1.
    dones = train_batch[SampleBatch.DONES].float()
    q_tp1_best_one_hot_selection = F.one_hot(
        torch.argmax(q_tp1, 1), policy.action_space.n)
    q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           policy.config["gamma"] * q_tp1_best_masked)
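    # I.e., the one-step Bellman backup target:
    #   y_t = r_t + gamma * (1 - done_t) * max_a Q_target(s_{t+1}, a).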

    # Compute the error (Square/Huber).
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Store values for the stats function in the model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["loss"] = loss
    # The TD-error tensor in the final stats will be concatenated and
    # retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss


def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    return {"loss": torch.mean(torch.stack(policy.get_tower_stats("loss")))}


def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    """Adds q-values to the action out dict."""
    return {"q_values": policy.q_values}


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before SimpleQTorchPolicy
    initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy)


SimpleQTorchPolicy = build_policy_class(
    name="SimpleQPolicy",
    framework="torch",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.simple_q.DEFAULT_CONFIG,
    stats_fn=stats_fn,
    extra_action_out_fn=extra_action_out_fn,
    after_init=setup_late_mixins,
    make_model_and_action_dist=build_q_model_and_distribution,
    mixins=[TargetNetworkMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
)
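
# Example usage (a minimal sketch, not part of this module; the trainer import
# path and the "CartPole-v0" env are assumptions for illustration only):
#
#   from ray.rllib.agents.dqn import SimpleQTrainer
#   trainer = SimpleQTrainer(env="CartPole-v0", config={"framework": "torch"})
#   print(trainer.train())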