from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.tf_ops import scope_vars

tf = try_import_tf()

POLICY_SCOPE = "p_func"
VALUE_SCOPE = "v_func"


class ValueLoss:
    def __init__(self, state_values, cumulative_rewards):
        self.loss = 0.5 * tf.reduce_mean(
            tf.square(state_values - cumulative_rewards))
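

# The ReweightedImitationLoss below implements MARWIL's exponentially
# weighted imitation objective
#     L_pi = -E[exp(beta * A(s, a) / sqrt(c)) * log pi_theta(a | s)],
# where A(s, a) = R(s, a) - V(s) is estimated from the batched returns and
# c is a slowly moving average of the mean squared advantage, used only to
# normalize the exponent (no gradient flows through the weights).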
class ReweightedImitationLoss:
    def __init__(self, state_values, cumulative_rewards, logits, actions,
                 action_space, beta, model):
        ma_adv_norm = tf.get_variable(
            name="moving_average_of_advantage_norm",
            dtype=tf.float32,
            initializer=100.0,
            trainable=False)
        # advantage estimation
        adv = cumulative_rewards - state_values
        # update averaged advantage norm
        update_adv_norm = tf.assign_add(
            ref=ma_adv_norm,
            value=1e-6 * (tf.reduce_mean(tf.square(adv)) - ma_adv_norm))

        # exponentially weighted advantages
        with tf.control_dependencies([update_adv_norm]):
            exp_advs = tf.exp(
                beta * tf.divide(adv, 1e-8 + tf.sqrt(ma_adv_norm)))

        # log\pi_\theta(a|s)
        dist_class, _ = ModelCatalog.get_action_dist(action_space, {})
        action_dist = dist_class(logits, model)
        logprobs = action_dist.logp(actions)

        self.loss = -1.0 * tf.reduce_mean(
            tf.stop_gradient(exp_advs) * logprobs)


class MARWILPostprocessing:
    """Adds the advantages field to the trajectory."""

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        completed = sample_batch[SampleBatch.DONES][-1]
        if completed:
            last_r = 0.0
        else:
            raise NotImplementedError(
                "The last done mask in a batch should be True. For now, we "
                "only support reading experience batches produced with "
                "batch_mode='complete_episodes'.",
                len(sample_batch[SampleBatch.DONES]),
                sample_batch[SampleBatch.DONES][-1])
        batch = compute_advantages(
            sample_batch, last_r, gamma=self.config["gamma"], use_gae=False)
        return batch
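

# With use_gae=False, compute_advantages() above stores the discounted
# return R_t = r_t + gamma * r_{t+1} + ... (bootstrapped with last_r = 0.0
# for complete episodes) in the Postprocessing.ADVANTAGES column; the
# policy below feeds that column in as `cum_rew_t`, which serves both as
# the value-regression target and as the return used for advantage
# weighting.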
class MARWILPolicy(MARWILPostprocessing, TFPolicy):
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG, **config)
        self.config = config
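
        # Config keys consumed by this policy: "model" (network config),
        # "gamma" (discount used in postprocessing), "beta" (advantage
        # weighting temperature of the imitation loss) and "vf_coeff"
        # (weight of the value regression loss).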

        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        # Action inputs
        self.obs_t = tf.placeholder(
            tf.float32, shape=(None, ) + observation_space.shape)
        prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
        prev_rewards_ph = tf.placeholder(
            tf.float32, [None], name="prev_reward")

        with tf.variable_scope(POLICY_SCOPE) as scope:
            self.model = ModelCatalog.get_model({
                "obs": self.obs_t,
                "prev_actions": prev_actions_ph,
                "prev_rewards": prev_rewards_ph,
                "is_training": self._get_is_training_placeholder(),
            }, observation_space, action_space, logit_dim,
                                                self.config["model"])
            logits = self.model.outputs
            self.p_func_vars = scope_vars(scope.name)

        # Action outputs
        action_dist = dist_class(logits, self.model)
        self.output_actions = action_dist.sample()

        # Training inputs
        self.act_t = ModelCatalog.get_action_placeholder(action_space)
        self.cum_rew_t = tf.placeholder(tf.float32, [None], name="reward")

        # v network evaluation
        with tf.variable_scope(VALUE_SCOPE) as scope:
            state_values = self.model.value_function()
            self.v_func_vars = scope_vars(scope.name)
        self.v_loss = self._build_value_loss(state_values, self.cum_rew_t)
        self.p_loss = self._build_policy_loss(state_values, self.cum_rew_t,
                                              logits, self.act_t, action_space)

        # total objective: imitation loss plus vf_coeff-scaled value loss
        objective = (
            self.p_loss.loss + self.config["vf_coeff"] * self.v_loss.loss)

        self.explained_variance = tf.reduce_mean(
            explained_variance(self.cum_rew_t, state_values))

        # initialize TFPolicy
        self.sess = tf.get_default_session()
        self.loss_inputs = [
            (SampleBatch.CUR_OBS, self.obs_t),
            (SampleBatch.ACTIONS, self.act_t),
            (Postprocessing.ADVANTAGES, self.cum_rew_t),
        ]
        TFPolicy.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=self.obs_t,
            action_sampler=self.output_actions,
            action_logp=action_dist.sampled_action_logp(),
            loss=objective,
            model=self.model,
            loss_inputs=self.loss_inputs,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions_ph,
            prev_reward_input=prev_rewards_ph)
        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            "total_loss": objective,
            "vf_explained_var": self.explained_variance,
            "policy_loss": self.p_loss.loss,
            "vf_loss": self.v_loss.loss
        }

    def _build_value_loss(self, state_values, cum_rwds):
        return ValueLoss(state_values, cum_rwds)

    def _build_policy_loss(self, state_values, cum_rwds, logits, actions,
                           action_space):
        return ReweightedImitationLoss(state_values, cum_rwds, logits,
                                       actions, action_space,
                                       self.config["beta"], self.model)

    @override(TFPolicy)
    def extra_compute_grad_fetches(self):
        return {LEARNER_STATS_KEY: self.stats_fetches}

    @override(Policy)
    def get_initial_state(self):
        return self.model.state_init
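

# A minimal usage sketch, assuming this RLlib version's MARWIL trainer and
# offline experiences at the hypothetical path below; the policy is normally
# driven by the trainer on offline batches rather than instantiated directly.
if __name__ == "__main__":
    from ray.rllib.agents.marwil import MARWILTrainer

    ray.init()
    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={
            # Hypothetical directory of JSON experiences recorded with
            # batch_mode="complete_episodes".
            "input": "/tmp/cartpole-out",
            "beta": 1.0,  # beta = 0.0 reduces MARWIL to behavior cloning
        })
    for _ in range(3):
        print(trainer.train())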