import logging
import gym
from typing import Optional, Dict

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
    PolicyID
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class ValueNetworkMixin:
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, config: TrainerConfigDict):

        # Input dict is provided to us automatically via the Model's
        # requirements. It's a single-timestep (last one in trajectory)
        # input_dict.
        @make_tf_callable(self.get_session())
        def value(**input_dict):
            model_out, _ = self.model.from_batch(
                input_dict, is_training=False)
            # [0] = remove the batch dim.
            return self.model.value_function()[0]

        self._value = value
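

# `ValueNetworkMixin` binds `policy._value`, a tf-callable that evaluates the
# model on a single (the last) timestep and returns the value estimate with
# the batch dimension removed. `postprocess_advantages()` below calls it to
# bootstrap the return of a truncated rollout.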


def postprocess_advantages(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
        episode=None) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`).
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from
            the same episode). NOTE: The other agents use the same policy.
        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """

    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Create an input dict according to the Model's view requirements.
        # It is provided to us automatically and contains only the single
        # (last) timestep of the trajectory.
        index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index=index)
        last_r = policy._value(**input_dict)

    # Add the "advantages" column (which in the case of MARWIL are simply the
    # discounted cumulative rewards) to the SampleBatch.
    return compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        # We just want the discounted cumulative rewards, so we need neither
        # GAE nor a critic (use_critic=True would subtract vf-estimates from
        # the returns).
        use_gae=False,
        use_critic=False)
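

# Illustrative sketch (not used by the policy): with `use_gae=False` and
# `use_critic=False`, `compute_advantages` fills Postprocessing.ADVANTAGES
# with discounted cumulative rewards, bootstrapped with `last_r` for
# truncated episodes. Conceptually:
#
#     returns[t] = r[t] + gamma * returns[t + 1],  returns[T] = last_r
def _example_discounted_returns(rewards, last_r, gamma):
    """Pure-Python reference for the targets MARWIL regresses against."""
    returns = []
    running = last_r
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))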


class MARWILLoss:
    def __init__(self, policy: Policy, value_estimates: TensorType,
                 action_dist: ActionDistribution, actions: TensorType,
                 cumulative_rewards: TensorType, vf_loss_coeff: float,
                 beta: float):

        # Advantage estimation.
        adv = cumulative_rewards - value_estimates
        adv_squared = tf.reduce_mean(tf.math.square(adv))

        # Value function's loss term (MSE).
        self.v_loss = 0.5 * adv_squared

        if beta != 0.0:
            # Perform moving averaging of advantage^2.

            # Update averaged advantage norm.
            # Eager.
            if policy.config["framework"] in ["tf2", "tfe"]:
                update_term = adv_squared - policy._moving_average_sqd_adv_norm
                policy._moving_average_sqd_adv_norm.assign_add(
                    1e-7 * update_term)

                # Exponentially weighted advantages.
                c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
                exp_advs = tf.math.exp(beta * (adv / (1e-8 + c)))
            # Static graph.
            else:
                update_adv_norm = tf1.assign_add(
                    ref=policy._moving_average_sqd_adv_norm,
                    value=1e-7 *
                    (adv_squared - policy._moving_average_sqd_adv_norm))

                # Exponentially weighted advantages.
                with tf1.control_dependencies([update_adv_norm]):
                    exp_advs = tf.math.exp(beta * tf.math.divide(
                        adv, 1e-8 + tf.math.sqrt(
                            policy._moving_average_sqd_adv_norm)))
            exp_advs = tf.stop_gradient(exp_advs)
        else:
            exp_advs = 1.0

        # L = - A * log\pi_\theta(a|s)
        logprobs = action_dist.logp(actions)
        self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)

        self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss

        self.explained_variance = tf.reduce_mean(
            explained_variance(cumulative_rewards, value_estimates))
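

# Summary of the objective assembled above (per train batch):
#
#     adv      = G - V(s)                     # G: discounted return
#     L_vf     = 0.5 * E[adv^2]
#     exp_advs = exp(beta * adv / c)          # c^2: moving avg of adv^2
#     L_pi     = -E[exp_advs * log pi(a|s)]
#     L_total  = L_pi + vf_coeff * L_vf
#
# With beta == 0.0, exp_advs is 1.0 and L_pi reduces to plain behavior
# cloning (negative log-likelihood of the dataset actions).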


def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
                train_batch: SampleBatch) -> TensorType:
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)
    value_estimates = model.value_function()

    policy.loss = MARWILLoss(policy, value_estimates, action_dist,
                             train_batch[SampleBatch.ACTIONS],
                             train_batch[Postprocessing.ADVANTAGES],
                             policy.config["vf_coeff"], policy.config["beta"])

    return policy.loss.total_loss


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    return {
        "policy_loss": policy.loss.p_loss,
        "vf_loss": policy.loss.v_loss,
        "total_loss": policy.loss.total_loss,
        "vf_explained_var": policy.loss.explained_variance,
    }


def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    # Set up a tf-var for the moving avg (do this here to make it work with
    # eager mode); "c^2" in the paper.
    policy._moving_average_sqd_adv_norm = get_variable(
        100.0,
        framework="tf",
        tf_name="moving_average_of_advantage_norm",
        trainable=False)
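

# The moving average of squared advantages ("c^2" in the MARWIL paper) starts
# at 100.0 and is nudged toward the per-batch estimate inside `MARWILLoss`
# with a small step size (1e-7), in both eager and static-graph mode.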


MARWILTFPolicy = build_tf_policy(
    name="MARWILTFPolicy",
    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
    loss_fn=marwil_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_advantages,
    before_loss_init=setup_mixins,
    compute_gradients_fn=compute_and_clip_gradients,
    mixins=[ValueNetworkMixin])
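
# Minimal usage sketch (illustration only; assumes an offline JSON dataset at
# the given path and RLlib's MARWILTrainer, which builds on this policy):
#
#     from ray.rllib.agents.marwil import MARWILTrainer
#
#     trainer = MARWILTrainer(
#         env="CartPole-v0",
#         config={
#             "framework": "tf",
#             "beta": 1.0,  # beta=0.0 reduces MARWIL to behavior cloning.
#             "input": "/path/to/offline/episodes.json",
#         })
#     for _ in range(10):
#         print(trainer.train())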