ray/rllib/agents/a3c/a3c_tf_policy.py

137 lines
4.9 KiB
Python
Raw Normal View History

"""Note: Keep in sync with changes to VTraceTFPolicy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.tf_ops import make_tf_callable
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
[rllib] Modularize Torch and TF policy graphs (#2294) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * cast * clean up * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * clarify * copy * async sa * fix
2018-06-26 13:17:15 -07:00
class A3CLoss(object):
    """Builds the A3C loss terms as TensorFlow ops.

    After construction the instance exposes four attributes:
    ``pi_loss``, ``vf_loss``, ``entropy`` and ``total_loss``.
    """

    def __init__(self,
                 action_dist,
                 actions,
                 advantages,
                 v_target,
                 vf,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01):
        """Create the loss ops.

        Args:
            action_dist: Distribution object; its ``logp(actions)`` and
                ``entropy()`` methods are called.
            actions: Actions whose log-probability is evaluated.
            advantages: Weights multiplied with the log-probabilities.
            v_target: Targets that are subtracted from ``vf``.
            vf: Value-function predictions.
            vf_loss_coeff: Multiplier applied to ``vf_loss`` inside
                ``total_loss``.
            entropy_coeff: Multiplier applied to ``entropy`` inside
                ``total_loss``.
        """
        action_logp = action_dist.logp(actions)

        # Policy-gradient term: negated sum of advantage-weighted log-probs.
        self.pi_loss = -tf.reduce_sum(action_logp * advantages)

        # Value term: half of the summed squared prediction error.
        value_error = vf - v_target
        squared_error = tf.square(value_error)
        self.vf_loss = 0.5 * tf.reduce_sum(squared_error)

        self.entropy = tf.reduce_sum(action_dist.entropy())

        # total = pi_loss + vf_loss * vf_loss_coeff - entropy * entropy_coeff
        # (entropy is subtracted, so higher entropy lowers the loss).
        combined = self.pi_loss + self.vf_loss * vf_loss_coeff
        self.total_loss = combined - self.entropy * entropy_coeff
def actor_critic_loss(policy, model, dist_class, train_batch):
    """Build the A3C loss for ``train_batch`` and return its total.

    The constructed ``A3CLoss`` object is stored on ``policy.loss`` so that
    its individual terms can be read later (``stats`` does so).

    Args:
        policy: Policy object; ``config["vf_loss_coeff"]`` and
            ``config["entropy_coeff"]`` are read and ``loss`` is assigned.
        model: Model providing ``from_batch()`` and ``value_function()``.
        dist_class: Callable invoked as ``dist_class(model_out, model)``.
        train_batch: Batch indexed by ``SampleBatch`` / ``Postprocessing``
            keys.

    Returns:
        ``policy.loss.total_loss``.
    """
    # Only the first element of the forward-pass result is used.
    logits, _ = model.from_batch(train_batch)
    dist = dist_class(logits, model)

    taken_actions = train_batch[SampleBatch.ACTIONS]
    advantages = train_batch[Postprocessing.ADVANTAGES]
    value_targets = train_batch[Postprocessing.VALUE_TARGETS]
    value_preds = model.value_function()

    policy.loss = A3CLoss(
        dist,
        taken_actions,
        advantages,
        value_targets,
        value_preds,
        policy.config["vf_loss_coeff"],
        policy.config["entropy_coeff"])
    return policy.loss.total_loss
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Run ``compute_advantages`` on a sample batch.

    If the last ``DONES`` entry of the batch is truthy, the bootstrap value
    is 0.0; otherwise it is obtained from ``policy._value`` evaluated on the
    last ``NEXT_OBS`` / ``ACTIONS`` / ``REWARDS`` entries and the last
    ``state_out_<i>`` entries.

    Args:
        policy: Policy providing ``num_state_tensors()``, ``_value`` and
            ``config`` (keys ``"gamma"`` and ``"lambda"``).
        sample_batch: Batch to post-process.
        other_agent_batches: Unused in this function.
        episode: Unused in this function.

    Returns:
        The result of ``compute_advantages``.
    """
    bootstrap_value = 0.0

    if not sample_batch[SampleBatch.DONES][-1]:
        # NOTE(review): each state entry is wrapped in a one-element list
        # here, and the `value` callable in ValueNetworkMixin wraps every
        # state in a list again before converting it to a tensor. Confirm
        # the resulting rank is what the model expects.
        final_state = [
            [sample_batch["state_out_{}".format(idx)][-1]]
            for idx in range(policy.num_state_tensors())
        ]
        bootstrap_value = policy._value(
            sample_batch[SampleBatch.NEXT_OBS][-1],
            sample_batch[SampleBatch.ACTIONS][-1],
            sample_batch[SampleBatch.REWARDS][-1],
            *final_state)

    return compute_advantages(
        sample_batch,
        bootstrap_value,
        policy.config["gamma"],
        policy.config["lambda"])
def add_value_function_fetch(policy):
    """Return the extra fetch dict holding the value-function output.

    Args:
        policy: Policy whose ``model.value_function()`` is called.

    Returns:
        A one-entry dict mapping ``SampleBatch.VF_PREDS`` to that output.
    """
    value_preds = policy.model.value_function()
    return {SampleBatch.VF_PREDS: value_preds}
class ValueNetworkMixin(object):
    """Mixin that attaches a ``_value`` callable to the host object.

    ``__init__`` expects ``self`` to provide ``get_session()`` and ``model``.
    """

    def __init__(self):
        @make_tf_callable(self.get_session())
        def value(ob, prev_action, prev_reward, *state):
            """Run the model on one step and return its first value output.

            Every argument is wrapped in a one-element list before being
            converted to a tensor.
            """
            model_inputs = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
                SampleBatch.PREV_ACTIONS: tf.convert_to_tensor([prev_action]),
                SampleBatch.PREV_REWARDS: tf.convert_to_tensor([prev_reward]),
                "is_training": tf.convert_to_tensor(False),
            }
            state_inputs = [tf.convert_to_tensor([s]) for s in state]
            # Third positional model argument; presumably the sequence
            # lengths (a single sequence of length 1) -- confirm against
            # the model's call signature.
            step_count = tf.convert_to_tensor([1])

            # The forward pass is run for its effect on the model; the
            # result is then read via value_function().
            _, _ = self.model(model_inputs, state_inputs, step_count)
            return self.model.value_function()[0]

        self._value = value
def stats(policy, train_batch):
    """Collect the loss-related statistics for this policy.

    Args:
        policy: Policy providing ``cur_lr``, ``loss`` (with ``pi_loss``,
            ``entropy`` and ``vf_loss``) and ``model``.
        train_batch: Unused in this function.

    Returns:
        Dict with keys ``cur_lr``, ``policy_loss``, ``policy_entropy``,
        ``var_gnorm`` and ``vf_loss``.
    """
    report = {}
    # The learning rate is reported as float64.
    report["cur_lr"] = tf.cast(policy.cur_lr, tf.float64)
    report["policy_loss"] = policy.loss.pi_loss
    report["policy_entropy"] = policy.loss.entropy
    # Global norm over all trainable model variables.
    model_vars = list(policy.model.trainable_variables())
    report["var_gnorm"] = tf.global_norm(model_vars)
    report["vf_loss"] = policy.loss.vf_loss
    return report
def grad_stats(policy, train_batch, grads):
    """Collect gradient-related statistics.

    Args:
        policy: Policy whose ``model.value_function()`` is called.
        train_batch: Batch from which ``Postprocessing.VALUE_TARGETS`` is
            read.
        grads: Gradients passed to ``tf.global_norm``.

    Returns:
        Dict with keys ``grad_gnorm`` and ``vf_explained_var``.
    """
    gradient_norm = tf.global_norm(grads)
    value_fit = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS],
        policy.model.value_function())
    return {
        "grad_gnorm": gradient_norm,
        "vf_explained_var": value_fit,
    }
def clip_gradients(policy, optimizer, loss):
    """Compute gradients of ``loss`` and clip them by global norm.

    Args:
        policy: Policy providing ``model.trainable_variables()`` and
            ``config["grad_clip"]``.
        optimizer: Object providing ``compute_gradients(loss, var_list)``.
        loss: Loss to differentiate.

    Returns:
        List of ``(clipped_gradient, variable)`` pairs.
    """
    trainable = policy.model.trainable_variables()
    raw_pairs = optimizer.compute_gradients(loss, trainable)

    # Only the gradients are clipped; the global norm that
    # clip_by_global_norm also returns is discarded.
    raw_grads = [grad for (grad, _) in raw_pairs]
    clipped, _ = tf.clip_by_global_norm(raw_grads,
                                        policy.config["grad_clip"])

    # Pair the clipped gradients with a fresh listing of the variables.
    return list(zip(clipped, policy.model.trainable_variables()))
def setup_mixins(policy, obs_space, action_space, config):
    """Initialise the mixin classes on ``policy``.

    Args:
        policy: Object on which both mixin ``__init__`` methods are run.
        obs_space: Unused in this function.
        action_space: Unused in this function.
        config: Mapping from which ``"lr"`` and ``"lr_schedule"`` are read.
    """
    ValueNetworkMixin.__init__(policy)

    base_lr = config["lr"]
    schedule = config["lr_schedule"]
    LearningRateSchedule.__init__(policy, base_lr, schedule)
# Assemble the A3C TensorFlow policy class from the functions and mixins
# defined in this module.
A3CTFPolicy = build_tf_policy(
    name="A3CTFPolicy",
    # Wrapped in a lambda so the attribute chain is resolved when the
    # callable is invoked rather than when this module is imported.
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    # Mixins and the hook that initialises them.
    mixins=[ValueNetworkMixin, LearningRateSchedule],
    before_loss_init=setup_mixins,
    # Sampling-side hooks.
    postprocess_fn=postprocess_advantages,
    extra_action_fetches_fn=add_value_function_fetch,
    # Loss and gradient hooks.
    loss_fn=actor_critic_loss,
    gradients_fn=clip_gradients,
    # Reporting hooks.
    stats_fn=stats,
    grad_stats_fn=grad_stats)