"""Note: Keep in sync with changes to VTraceTFPolicy."""

from typing import Optional, Dict

import gym

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
    Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
    PolicyID, LocalOptimizer, ModelGradients
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.evaluation import MultiAgentEpisode

tf1, tf, tfv = try_import_tf()
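# `try_import_tf()` above returns a 3-tuple: `tf1` (the `tf.compat.v1` module),
# `tf` (the installed TensorFlow module, if any), and `tfv` (the major version
# as an int, 1 or 2), so this policy can be built under both TF1 and TF2.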


@Deprecated(
    old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=False)
def postprocess_advantages(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Deprecated; use `compute_gae_for_sample_batch()` instead."""
    return compute_gae_for_sample_batch(policy, sample_batch,
                                        other_agent_batches, episode)


class A3CLoss:
    """Holds the A3C policy-, value-, and entropy-loss terms."""

    def __init__(self,
                 action_dist: ActionDistribution,
                 actions: TensorType,
                 advantages: TensorType,
                 v_target: TensorType,
                 vf: TensorType,
                 valid_mask: TensorType,
                 vf_loss_coeff: float = 0.5,
                 entropy_coeff: float = 0.01,
                 use_critic: bool = True):
        log_prob = action_dist.logp(actions)

        # The "policy gradients" loss.
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(log_prob * advantages, valid_mask))

        delta = tf.boolean_mask(vf - v_target, valid_mask)

        # Compute a value function loss.
        if use_critic:
            self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta))
        # Ignore the value function.
        else:
            self.vf_loss = tf.constant(0.0)

        self.entropy = tf.reduce_sum(
            tf.boolean_mask(action_dist.entropy(), valid_mask))

        # Combined objective: policy-gradient term plus weighted value-function
        # loss, minus the weighted entropy bonus.
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)


def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    """Builds the A3C actor-critic loss, masking padded RNN timesteps."""
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)
    if policy.is_recurrent():
        # Build a mask that is True only for valid (non-padded) timesteps.
        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(train_batch[SampleBatch.REWARDS])
    policy.loss = A3CLoss(action_dist, train_batch[SampleBatch.ACTIONS],
                          train_batch[Postprocessing.ADVANTAGES],
                          train_batch[Postprocessing.VALUE_TARGETS],
                          model.value_function(), mask,
                          policy.config["vf_loss_coeff"],
                          policy.config["entropy_coeff"],
                          policy.config.get("use_critic", True))
    return policy.loss.total_loss
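# NOTE: The loss object is stored on `policy.loss` (not just returned) so that
# `stats()` and `grad_stats()` below can report its individual terms.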


def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]:
    """Adds the value-function prediction to the policy's action fetches."""
    return {SampleBatch.VF_PREDS: policy.model.value_function()}


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns learning statistics based on the policy's loss object."""
    return {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "policy_loss": policy.loss.pi_loss,
        "policy_entropy": policy.loss.entropy,
        "var_gnorm": tf.linalg.global_norm(
            list(policy.model.trainable_variables())),
        "vf_loss": policy.loss.vf_loss,
    }


def grad_stats(policy: Policy, train_batch: SampleBatch,
               grads: ModelGradients) -> Dict[str, TensorType]:
    """Returns gradient-norm and value-explained-variance statistics."""
    return {
        "grad_gnorm": tf.linalg.global_norm(grads),
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS],
            policy.model.value_function()),
    }


def clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                   loss: TensorType) -> ModelGradients:
    """Computes gradients and clips them by global norm (config "grad_clip")."""
    grads_and_vars = optimizer.compute_gradients(
        loss, policy.model.trainable_variables())
    grads = [g for (g, v) in grads_and_vars]
    grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
    clipped_grads = list(zip(grads, policy.model.trainable_variables()))
    return clipped_grads


def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Initializes the mixin classes before the loss is constructed."""
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


A3CTFPolicy = build_tf_policy(
    name="A3CTFPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=stats,
    grad_stats_fn=grad_stats,
    compute_gradients_fn=clip_gradients,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_action_out_fn=add_value_function_fetch,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin, LearningRateSchedule])
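# Usage sketch (illustrative only, not part of this module): A3CTFPolicy is
# normally used indirectly through the A3C trainer, which picks this policy
# class when a TF framework is configured, e.g.:
#
#     import ray
#     from ray.rllib.agents.a3c import A3CTrainer
#
#     ray.init()
#     trainer = A3CTrainer(env="CartPole-v0", config={"framework": "tf"})
#     print(trainer.train())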