mirror of
synced 2025-03-06 10:31:39 -05:00

* Policy-classes cleanup and torch/tf unification. - Make Policy abstract. - Add `action_dist` to call to `extra_action_out_fn` (necessary for PPO torch). - Move some methods and vars to base Policy (from TFPolicy): num_state_tensors, ACTION_PROB, ACTION_LOGP and some more. * Fix `clip_action` import from Policy (should probably be moved into utils altogether). * - Move `is_recurrent()` and `num_state_tensors()` into TFPolicy (from DynamicTFPolicy). - Add config to all Policy c'tor calls (as 3rd arg after obs and action spaces). * Add `config` to c'tor call to TFPolicy. * Add missing `config` to c'tor call to TFPolicy in marvil_policy.py. * Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (Policy base class is now abstract). * Fix LINT errors in Policy classes. * Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py. * policy.py LINT errors. * Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases). * policy.py - Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required iff `learn_on_batch` implemented). - Fix docstring of `num_state_tensors`. * Make QMIX torch Policy a child of TorchPolicy (instead of Policy). * QMixPolicy add empty implementations of abstract Policy methods. * Store Policy's config in self.config in base Policy c'tor. * - Make only compute_actions in base Policy's an abstractmethod and provide pass implementation to all other methods if not defined. - Fix state_batches=None (most Policies don't have internal states). * Cartpole tf learning. * Cartpole tf AND torch learning (in ~ same ts). * Cartpole tf AND torch learning (in ~ same ts). 2 * Cartpole tf (torch syntax-broken) learning (in ~ same ts). 3 * Cartpole tf AND torch learning (in ~ same ts). 4 * Cartpole tf AND torch learning (in ~ same ts). 5 * Cartpole tf AND torch learning (in ~ same ts). 6 * Cartpole tf AND torch learning (in ~ same ts). Pendulum tf learning. * WIP. 
* WIP. * SAC torch learning Pendulum. * WIP. * SAC torch and tf learning Pendulum and Cartpole after cleanup. * WIP. * LINT. * LINT. * SAC: Move policy.target_model to policy.device as well. * Fixes and cleanup. * Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default). * Fixes and LINT. * Fixes and LINT. * Fix and LINT. * WIP. * Test fixes and LINT. * Fixes and LINT. Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
343 lines
13 KiB
343 lines
13 KiB
from gym.spaces import Discrete
import logging
import ray
import ray.experimental.tf_utils
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.sac.sac_tf_policy import build_sac_model, \
from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.models.torch.torch_action_dist import (
TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
from ray.rllib.utils import try_import_torch
# Obtain the torch package and its `nn` namespace through RLlib's import
# helper rather than a direct `import torch`.
torch, nn = try_import_torch()
# Short alias for the functional API (`log_softmax`, `one_hot`) used by the
# loss function below.
F = nn.functional
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def build_sac_model_and_action_dist(policy, obs_space, action_space, config):
    """Build the SAC model and pick the matching action-distribution class.

    Args:
        policy: Policy object, handed through to `build_sac_model`.
        obs_space: Observation space, handed through to `build_sac_model`.
        action_space: Action space; also decides the distribution class.
        config: Config dict; also consulted by `get_dist_class`.

    Returns:
        A `(model, action_dist_class)` tuple.
    """
    # The model is built first, then the distribution class is looked up.
    sac_model = build_sac_model(policy, obs_space, action_space, config)
    return sac_model, get_dist_class(config, action_space)
def get_dist_class(config, action_space):
    """Select the torch action-distribution class for SAC.

    Args:
        config: Config dict; reads "normalize_actions" and, only when that
            is truthy, "_use_beta_distribution".
        action_space: The policy's action space.

    Returns:
        `TorchCategorical` for `Discrete` spaces. Otherwise
        `TorchDiagGaussian` when actions are not normalized, `TorchBeta`
        when the beta distribution is requested, and
        `TorchSquashedGaussian` in the remaining case.
    """
    if isinstance(action_space, Discrete):
        return TorchCategorical
    # Non-discrete action spaces from here on.
    if not config["normalize_actions"]:
        return TorchDiagGaussian
    if config["_use_beta_distribution"]:
        return TorchBeta
    return TorchSquashedGaussian
# Computes the inputs for the action distribution, together with the
# distribution class, for a batch of observations.
#
# Returns a 3-tuple: (distribution inputs, distribution class, empty list).
#
# NOTE(review): the parameter list is cut off after `policy,` in this copy of
# the file. The body also reads `model`, `obs_batch` and `is_training`, so
# those are presumably parameters as well -- restore the full signature from
# the upstream file rather than from this comment.
def action_distribution_fn(policy,
model_out, _ = model({
"obs": obs_batch,
"is_training": is_training,
}, [], None)
# Turn the model output into the distribution's input tensor.
distribution_inputs = model.get_policy_output(model_out)
# The distribution class depends only on the policy's config and action space.
action_dist_class = get_dist_class(policy.config, policy.action_space)
# The third element is always an empty list.
return distribution_inputs, action_dist_class, []
# Computes the SAC loss terms (actor, critic(s), entropy coefficient alpha)
# for one train batch and stores intermediate tensors on `policy` so that the
# stats function can report them.
#
# Args:
#   policy: Policy object; supplies `config`, `target_model`, `model` and
#       `action_space`, and receives the cached tensors at the end.
#   model: Model providing the forward pass, policy output, Q-values,
#       `log_alpha`, `target_entropy` and the `discrete` flag.
#   _: Unused.
#   train_batch: Batch indexed with SampleBatch.CUR_OBS, NEXT_OBS, ACTIONS,
#       REWARDS and DONES.
#
# Returns:
#   A tuple that starts with the actor loss followed by the critic loss(es);
#   the tail of the return expression is cut off in this copy.
#
# NOTE(review): this copy of the function has lost its indentation and a
# number of lines (`else:` branches and the closing parts of several
# multi-line expressions). Comments marked NOTE(review) flag those spots;
# restore the code from the upstream file, not from these comments.
def actor_critic_loss(policy, model, _, train_batch):
# Should be True only for debugging purposes (e.g. test cases)!
deterministic = policy.config["_deterministic_loss"]
# Forward passes: current and next observations through `model`, next
# observations through the target model as well.
model_out_t, _ = model({
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": True,
}, [], None)
model_out_tp1, _ = model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": True,
}, [], None)
target_model_out_tp1, _ = policy.target_model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": True,
}, [], None)
# Entropy coefficient: alpha = exp(log_alpha).
alpha = torch.exp(model.log_alpha)
# Discrete case.
if model.discrete:
# Get all action probs directly from pi and form their logp.
log_pis_t = F.log_softmax(model.get_policy_output(model_out_t), dim=-1)
policy_t = torch.exp(log_pis_t)
log_pis_tp1 = F.log_softmax(model.get_policy_output(model_out_tp1), -1)
policy_tp1 = torch.exp(log_pis_tp1)
# Q-values.
q_t = model.get_q_values(model_out_t)
# Target Q-values.
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
if policy.config["twin_q"]:
twin_q_t = model.get_twin_q_values(model_out_t)
# NOTE(review): the argument line and closing parenthesis of the next call
# are missing in this copy.
twin_q_tp1 = policy.target_model.get_twin_q_values(
q_tp1 = torch.min(q_tp1, twin_q_tp1)
# Subtract the entropy term from the target Q-values (in place).
q_tp1 -= alpha * log_pis_tp1
# Actually selected Q-values (from the actions batch).
one_hot = F.one_hot(
train_batch[SampleBatch.ACTIONS], num_classes=q_t.size()[-1])
q_t_selected = torch.sum(q_t * one_hot, dim=-1)
if policy.config["twin_q"]:
twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
# Discrete case: "Best" means weighted by the policy (prob) outputs.
q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
# Multiply by (1 - done) so the masked value is zero where DONES is set.
# NOTE(review): the final factor of this product is missing in this copy
# (presumably `q_tp1_best` -- confirm).
q_tp1_best_masked = \
(1.0 - train_batch[SampleBatch.DONES].float()) * \
# Continuous actions case.
# Sample single actions from distribution.
# NOTE(review): no `else:` separating the discrete branch above from this
# continuous branch is present in this copy -- presumably lost; confirm.
action_dist_class = get_dist_class(policy.config, policy.action_space)
action_dist_t = action_dist_class(
model.get_policy_output(model_out_t), policy.model)
# NOTE(review): the alternative used when `deterministic` is set (the
# expression after `else \`) is missing in this copy.
policy_t = action_dist_t.sample() if not deterministic else \
log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
action_dist_tp1 = action_dist_class(
model.get_policy_output(model_out_tp1), policy.model)
# NOTE(review): same truncation as for `policy_t` above.
policy_tp1 = action_dist_tp1.sample() if not deterministic else \
log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)
# Q-values for the actually selected actions.
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
if policy.config["twin_q"]:
twin_q_t = model.get_twin_q_values(
model_out_t, train_batch[SampleBatch.ACTIONS])
# Q-values for current policy in given current state.
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
if policy.config["twin_q"]:
twin_q_t_det_policy = model.get_twin_q_values(
model_out_t, policy_t)
q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)
# Target q network evaluation.
# NOTE(review): the second argument of the next call is missing in this copy;
# the twin call below passes `policy_tp1`, so presumably the same -- confirm.
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
if policy.config["twin_q"]:
twin_q_tp1 = policy.target_model.get_twin_q_values(
target_model_out_tp1, policy_tp1)
# Take min over both twin-NNs.
q_tp1 = torch.min(q_tp1, twin_q_tp1)
q_t_selected = torch.squeeze(q_t, dim=-1)
if policy.config["twin_q"]:
twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
q_tp1 -= alpha * log_pis_tp1
q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
# NOTE(review): the final factor of this product is missing in this copy
# (presumably `q_tp1_best` -- confirm).
q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * \
assert policy.config["n_step"] == 1, "TODO(hartikainen) n_step > 1"
# compute RHS of bellman equation
# NOTE(review): the closing part of this parenthesised expression is missing
# in this copy.
q_t_selected_target = (
train_batch[SampleBatch.REWARDS] +
(policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked
# Compute the TD-error (potentially clipped).
base_td_error = torch.abs(q_t_selected - q_t_selected_target)
if policy.config["twin_q"]:
twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
td_error = 0.5 * (base_td_error + twin_td_error)
# NOTE(review): as written, the next line overwrites the averaged TD-error
# unconditionally; an `else:` line presumably preceded it -- confirm.
td_error = base_td_error
# One squared-error loss per Q-network, collected in a list.
# NOTE(review): the closing `]` of this list is missing in this copy.
critic_loss = [
0.5 * torch.mean(torch.pow(q_t_selected_target - q_t_selected, 2.0))
if policy.config["twin_q"]:
critic_loss.append(0.5 * torch.mean(
torch.pow(q_t_selected_target - twin_q_t_selected, 2.0)))
# Alpha- and actor losses.
# Note: In the papers, alpha is used directly, here we take the log.
# Discrete case: Multiply the action probs as weights with the original
# loss terms (no expectations needed).
if model.discrete:
weighted_log_alpha_loss = policy_t.detach() * (
-model.log_alpha * (log_pis_t + model.target_entropy).detach())
# Sum up weighted terms and mean over all batch items.
alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
# Actor loss.
# NOTE(review): parts of the following expression (between `torch.mean(` and
# the `alpha.detach() ...` line, and after it) are missing in this copy.
actor_loss = torch.mean(
# NOTE: No stop_grad around policy output here
# (compare with q_t_det_policy for continuous case).
alpha.detach() * log_pis_t - q_t.detach()),
# NOTE(review): as written, the next statement overwrites the discrete
# `alpha_loss`; an `else:` for the continuous case presumably preceded it --
# confirm.
alpha_loss = -torch.mean(model.log_alpha *
(log_pis_t + model.target_entropy).detach())
# Note: Do not detach q_t_det_policy here b/c is depends partly
# on the policy vars (policy sample pushed through Q-net).
# However, we must make sure `actor_loss` is not used to update
# the Q-net(s)' variables.
actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)
# Save for stats function.
policy.q_t = q_t
policy.policy_t = policy_t
policy.log_pis_t = log_pis_t
policy.td_error = td_error
policy.actor_loss = actor_loss
policy.critic_loss = critic_loss
policy.alpha_loss = alpha_loss
policy.log_alpha_value = model.log_alpha
policy.alpha_value = alpha
policy.target_entropy = model.target_entropy
# Return all loss terms corresponding to our optimizers.
# NOTE(review): the tail of this return expression is missing in this copy.
return tuple([policy.actor_loss] + policy.critic_loss +
def stats(policy, train_batch):
    """Collect the training statistics cached on `policy`.

    The values are read from attributes that the loss function stores on the
    policy; nothing is recomputed from the batch.

    Args:
        policy: Object carrying the attributes `td_error`, `actor_loss`,
            `critic_loss` (a sequence of tensors, stacked before averaging),
            `alpha_loss`, `alpha_value`, `log_alpha_value`,
            `target_entropy`, `policy_t` and `q_t`.
        train_batch: Unused.

    Returns:
        Dict mapping stat names to tensors. "td_error" and "target_entropy"
        are passed through unchanged; the other entries are mean/max/min
        reductions of the cached tensors.
    """
    q_t = policy.q_t
    return {
        "td_error": policy.td_error,
        "mean_td_error": torch.mean(policy.td_error),
        "actor_loss": torch.mean(policy.actor_loss),
        # `critic_loss` is a list (one entry per Q-net), hence the stack.
        "critic_loss": torch.mean(torch.stack(policy.critic_loss)),
        "alpha_loss": torch.mean(policy.alpha_loss),
        "alpha_value": torch.mean(policy.alpha_value),
        "log_alpha_value": torch.mean(policy.log_alpha_value),
        "target_entropy": policy.target_entropy,
        "policy_t": torch.mean(policy.policy_t),
        "mean_q": torch.mean(q_t),
        "max_q": torch.max(q_t),
        "min_q": torch.min(q_t),
    }
# Creates the Adam optimizers used for SAC and stores them on `policy` as
# `actor_optim`, `critic_optims` (a list) and `alpha_optim`. Per the
# docstring, 3 optimizers are returned, or 4 when config["twin_q"] is set,
# matching the number of loss terms returned by the loss function.
#
# NOTE(review): this copy has lost its indentation, the closing `"""` of the
# docstring, most arguments of the optimizer constructors (only the `eps`
# lines survive) and the tail of the return expression. Which parameters and
# learning rates each optimizer receives cannot be determined from here --
# restore the code from the upstream file.
def optimizer_fn(policy, config):
"""Creates all necessary optimizers for SAC learning.
The 3 or 4 (twin_q=True) optimizers returned here correspond to the
number of loss terms returned by the loss function.
policy.actor_optim = torch.optim.Adam(
eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default
critic_split = len(policy.model.q_variables())
if config["twin_q"]:
critic_split //= 2
policy.critic_optims = [
eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default
if config["twin_q"]:
eps=1e-7, # to match tf.keras.optimizers.Adam's eps default
policy.alpha_optim = torch.optim.Adam(
eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default
return tuple([policy.actor_optim] + policy.critic_optims +
class ComputeTDErrorMixin:
    """Mixin that attaches a `compute_td_error` callable to the instance.

    The callable is created as a closure inside `__init__` and stored as an
    instance attribute, so it is called without an explicit `self`.

    The host object is expected to provide `_lazy_tensor_dict` and `model`,
    and to be usable as the `policy` argument of `actor_critic_loss`
    (presumably supplied by the policy class this is mixed into -- confirm).

    NOTE(review): in the copy this was restored from, the inner function's
    signature ended after `done_mask,` and the dict literal was not closed.
    The final parameter `importance_weights` is taken from its use under the
    `PRIO_WEIGHTS` key; confirm against the upstream file.
    """

    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            """Run the loss on the given batch columns; return `td_error`."""
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_t,
                SampleBatch.ACTIONS: act_t,
                SampleBatch.REWARDS: rew_t,
                SampleBatch.NEXT_OBS: obs_tp1,
                SampleBatch.DONES: done_mask,
                PRIO_WEIGHTS: importance_weights,
            })
            # Do forward pass on loss to update td errors attribute
            # (one TD-error value per item in batch to update PR weights).
            actor_critic_loss(self, self.model, None, input_dict)

            # Self.td_error is set within actor_critic_loss call.
            return self.td_error

        self.compute_td_error = compute_td_error
# Mixin providing `update_target`, which blends the weights of `self.model`
# with those of `self.target_model` into a new state dict.
#
# NOTE(review): this copy has lost its indentation and several lines: the
# body of `__init__` (only its comment survives), the closing `}` of the
# dict comprehension, and whatever applies `model_state_dict` to the target
# model afterwards. Restore the code from the upstream file.
class TargetNetworkMixin:
def __init__(self):
# Hard initial update from Q-net(s) to target Q-net(s).
# `tau` is the blend factor; when omitted it is read from the config.
def update_target(self, tau=None):
# Update_target_fn will be called periodically to copy Q network to
# target Q network, using (soft) tau-synching.
# Note that `or` also falls back to config["tau"] for any falsy `tau`
# (e.g. 0), not only for None.
tau = tau or self.config.get("tau")
model_state_dict = self.model.state_dict()
# Support partial (soft) synching.
# If tau == 1.0: Full sync from Q-model to target Q-model.
if tau != 1.0:
target_state_dict = self.target_model.state_dict()
# Per entry: tau * (model value) + (1 - tau) * (target value).
model_state_dict = {
k: tau * model_state_dict[k] + (1 - tau) * v
for k, v in target_state_dict.items()
def setup_late_mixins(policy, obs_space, action_space, config):
    """Move the target model and the alpha tensors onto `policy.device`.

    Args:
        policy: Policy exposing `device`, `target_model` and `model`; the
            model carries `log_alpha` and `target_entropy`.
        obs_space: Unused.
        action_space: Unused.
        config: Unused.

    NOTE(review): despite the name, the visible body initialises no mixins;
    confirm against the upstream file whether such calls were lost.
    """
    device = policy.device
    policy.target_model = policy.target_model.to(device)
    sac_model = policy.model
    sac_model.log_alpha = sac_model.log_alpha.to(device)
    sac_model.target_entropy = sac_model.target_entropy.to(device)
# Assemble the SAC torch policy class via the `build_torch_policy` template.
# NOTE(review): most keyword arguments of this call, and its closing
# parenthesis, are missing from this copy; only the default-config getter and
# the mixin list are visible. Restore the call from the upstream file.
SACTorchPolicy = build_torch_policy(
get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
mixins=[TargetNetworkMixin, ComputeTDErrorMixin],