ray/rllib/agents/a3c/a3c_tf_policy.py

148 lines
5.8 KiB
Python

"""Note: Keep in sync with changes to VTraceTFPolicy."""
from typing import Optional, Dict
import gym
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
EntropyCoeffSchedule
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
PolicyID, LocalOptimizer, ModelGradients
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.evaluation import MultiAgentEpisode
# Resolve TensorFlow through RLlib's optional-import helper. Presumably this
# yields (tf.compat.v1 module, tf module, TF major version) -- confirm against
# ray.rllib.utils.framework. Only ``tf`` is referenced in the code below.
tf1, tf, tfv = try_import_tf()
@Deprecated(
    old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=False)
def postprocess_advantages(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Deprecated alias of ``compute_gae_for_sample_batch``.

    Retained only for backward compatibility: every argument is forwarded
    positionally, unchanged, to the replacement function.

    Args:
        policy: The policy whose trajectory batch is being post-processed.
        sample_batch: The batch to post-process.
        other_agent_batches: Optional batches of other agents, keyed by
            policy ID.
        episode: Optional episode object the batch belongs to.

    Returns:
        The batch returned by ``compute_gae_for_sample_batch``.
    """
    forwarded_args = (policy, sample_batch, other_agent_batches, episode)
    return compute_gae_for_sample_batch(*forwarded_args)
class A3CLoss:
    """Masked actor-critic loss terms for A3C.

    Every term is a sum (not a mean) over the entries selected by
    ``valid_mask``. After construction the instance exposes the tensors
    ``pi_loss``, ``vf_loss``, ``entropy`` and ``total_loss``.
    """

    def __init__(self,
                 action_dist: ActionDistribution,
                 actions: TensorType,
                 advantages: TensorType,
                 v_target: TensorType,
                 vf: TensorType,
                 valid_mask: TensorType,
                 vf_loss_coeff: float = 0.5,
                 entropy_coeff: float = 0.01,
                 use_critic: bool = True):
        """Build the loss tensors.

        Args:
            action_dist: Action distribution produced by the model.
            actions: Actions taken; scored with ``action_dist.logp``.
            advantages: Weights applied to the action log-likelihoods.
            v_target: Value-function targets.
            vf: Value-function predictions.
            valid_mask: Mask handed to ``tf.boolean_mask`` to select the
                entries that contribute to each sum.
            vf_loss_coeff: Multiplier on the value loss in ``total_loss``.
            entropy_coeff: Multiplier on the entropy, subtracted in
                ``total_loss``.
            use_critic: When False the value loss is the constant 0.0.
        """

        def masked_sum(values):
            # Sum only over the entries that ``valid_mask`` keeps.
            return tf.reduce_sum(tf.boolean_mask(values, valid_mask))

        logp = action_dist.logp(actions)
        # Policy-gradient term: negative advantage-weighted log-likelihood.
        self.pi_loss = -masked_sum(logp * advantages)

        masked_error = tf.boolean_mask(vf - v_target, valid_mask)
        if not use_critic:
            # Critic disabled: the value function contributes nothing.
            self.vf_loss = tf.constant(0.0)
        else:
            # Half the summed squared error of the value predictions.
            self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(masked_error))

        self.entropy = masked_sum(action_dist.entropy())

        self.total_loss = (
            self.pi_loss
            + self.vf_loss * vf_loss_coeff
            - self.entropy * entropy_coeff
        )
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    """Build the A3C actor-critic loss for ``train_batch``.

    Runs the model on the batch, wraps its output in ``dist_class`` and
    constructs an ``A3CLoss``. The loss object is stored on ``policy.loss``
    so that ``stats`` can read its individual terms.

    Args:
        policy: The policy being trained; supplies ``is_recurrent()``,
            ``config`` and ``entropy_coeff``.
        model: Model producing the action-distribution inputs and the value
            predictions (``model.value_function()``).
        dist_class: Action-distribution class instantiated on the model
            output.
        train_batch: Batch providing actions, rewards, advantages, value
            targets and, for recurrent policies, sequence lengths.

    Returns:
        The total loss tensor (``policy.loss.total_loss``).
    """
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    if policy.is_recurrent():
        # ``sequence_mask`` marks positions < seq_len as valid; flattened to
        # 1-D. Assumes the batch rows are laid out padded to max_seq_len per
        # sequence -- TODO confirm against the batching code.
        seq_lens = train_batch[SampleBatch.SEQ_LENS]
        max_seq_len = tf.reduce_max(seq_lens)
        mask = tf.sequence_mask(seq_lens, max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        # Every timestep is valid. Build the mask as bool explicitly:
        # ``tf.boolean_mask`` expects a boolean mask, and this keeps the dtype
        # consistent with the recurrent branch instead of inheriting the
        # (float) dtype of the rewards tensor.
        mask = tf.ones_like(train_batch[SampleBatch.REWARDS], dtype=tf.bool)
    policy.loss = A3CLoss(action_dist, train_batch[SampleBatch.ACTIONS],
                          train_batch[Postprocessing.ADVANTAGES],
                          train_batch[Postprocessing.VALUE_TARGETS],
                          model.value_function(), mask,
                          policy.config["vf_loss_coeff"], policy.entropy_coeff,
                          policy.config.get("use_critic", True))
    return policy.loss.total_loss
def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]:
    """Return the extra action output for this policy.

    The only extra output is the model's value-function prediction, stored
    under ``SampleBatch.VF_PREDS``. Passed to ``build_tf_policy`` as
    ``extra_action_out_fn``.
    """
    value_estimate = policy.model.value_function()
    return {SampleBatch.VF_PREDS: value_estimate}
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Collect learner statistics for a training step.

    Reports the current learning rate and entropy coefficient (cast to
    float64), the loss terms stored on ``policy.loss`` by
    ``actor_critic_loss``, and the global norm of the model's trainable
    variables. ``train_batch`` is not used.
    """
    report = {}
    report["cur_lr"] = tf.cast(policy.cur_lr, tf.float64)
    report["entropy_coeff"] = tf.cast(policy.entropy_coeff, tf.float64)
    report["policy_loss"] = policy.loss.pi_loss
    report["policy_entropy"] = policy.loss.entropy
    model_variables = list(policy.model.trainable_variables())
    report["var_gnorm"] = tf.linalg.global_norm(model_variables)
    report["vf_loss"] = policy.loss.vf_loss
    return report
def grad_stats(policy: Policy, train_batch: SampleBatch,
               grads: ModelGradients) -> Dict[str, TensorType]:
    """Gradient-level statistics for a training step.

    Returns the global norm of ``grads`` and the explained variance of the
    model's value predictions with respect to the batch's value targets.
    """
    gradient_norm = tf.linalg.global_norm(grads)
    value_targets = train_batch[Postprocessing.VALUE_TARGETS]
    explained_var = explained_variance(value_targets,
                                       policy.model.value_function())
    return {
        "grad_gnorm": gradient_norm,
        "vf_explained_var": explained_var,
    }
def clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                   loss: TensorType) -> ModelGradients:
    """Compute gradients of ``loss`` and clip them by their global norm.

    Args:
        policy: Policy whose model's trainable variables are differentiated;
            ``policy.config["grad_clip"]`` is the clipping threshold.
        optimizer: Optimizer providing ``compute_gradients``.
        loss: The loss tensor to differentiate.

    Returns:
        A list of ``(gradient, variable)`` pairs. When ``grad_clip`` is None
        the gradients are returned unclipped.
    """
    grads_and_vars = list(optimizer.compute_gradients(
        loss, policy.model.trainable_variables()))
    grad_clip = policy.config["grad_clip"]
    if grad_clip is None:
        # Clipping disabled (tf.clip_by_global_norm cannot take None).
        return grads_and_vars
    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    clipped_grads, _ = tf.clip_by_global_norm(grads, grad_clip)
    # Re-pair with the variables the optimizer returned (rather than calling
    # trainable_variables() again) so each gradient stays attached to the
    # variable it was computed for.
    return list(zip(clipped_grads, variables))
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Initialize the policy's mixins; used as ``before_loss_init``.

    Order: value-network mixin, then learning-rate schedule, then
    entropy-coefficient schedule, each configured from ``config``.
    """
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)

    lr_args = (config["lr"], config["lr_schedule"])
    LearningRateSchedule.__init__(policy, *lr_args)

    entropy_args = (config["entropy_coeff"], config["entropy_coeff_schedule"])
    EntropyCoeffSchedule.__init__(policy, *entropy_args)
# TensorFlow A3C policy class assembled from the hooks defined above.
A3CTFPolicy = build_tf_policy(
    name="A3CTFPolicy",
    # The lambda defers the DEFAULT_CONFIG lookup until it is called,
    # presumably to avoid an import cycle with the a3c module -- confirm.
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    # Mixins and their initialization hook.
    mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule],
    before_loss_init=setup_mixins,
    # Sampling-side hooks.
    extra_action_out_fn=add_value_function_fetch,
    postprocess_fn=compute_gae_for_sample_batch,
    # Loss, gradient computation and reporting.
    loss_fn=actor_critic_loss,
    compute_gradients_fn=clip_gradients,
    stats_fn=stats,
    grad_stats_fn=grad_stats,
)