import logging
from typing import Any, Dict, List, Optional, Type, Union

import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import (
    ValueNetworkMixin,
    compute_gradients,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
)
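
# `try_import_tf()` returns the `tf.compat.v1` module (`tf1`), the TensorFlow
# module itself (`tf`), and the installed TF major version (`tfv`).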
tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class PostprocessAdvantages:
    """MARWIL's custom trajectory post-processing mixin."""

    def __init__(self):
        pass

    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ):
        sample_batch = super().postprocess_trajectory(
            sample_batch, other_agent_batches, episode
        )

        # Trajectory is actually complete -> last r=0.0.
        if sample_batch[SampleBatch.DONES][-1]:
            last_r = 0.0
        # Trajectory has been truncated -> last r=VF estimate of last obs.
        else:
            # Create a single-timestep input_dict (for the last timestep in the
            # trajectory) according to the Model's view requirements.
            index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
            input_dict = sample_batch.get_single_step_input_dict(
                self.model.view_requirements, index=index
            )
            last_r = self._value(**input_dict)

        # Add the "advantages" (which in the case of MARWIL are simply the
        # discounted cumulative rewards) to the SampleBatch.
        return compute_advantages(
            sample_batch,
            last_r,
            self.config["gamma"],
            # We just want the discounted cumulative rewards, so we need
            # neither GAE nor the critic (use_critic=True would subtract the
            # vf estimates from the returns).
            use_gae=False,
            use_critic=False,
        )
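

# Quick illustration: with use_gae=False and use_critic=False,
# `compute_advantages` stores the plain discounted returns under
# Postprocessing.ADVANTAGES. E.g. for rewards [1.0, 1.0, 1.0], gamma=0.9, and
# last_r=0.0, the resulting "advantages" are
# [1.0 + 0.9 + 0.81, 1.0 + 0.9, 1.0] = [2.71, 1.9, 1.0].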


class MARWILLoss:
    def __init__(
        self,
        policy: Policy,
        value_estimates: TensorType,
        action_dist: ActionDistribution,
        train_batch: SampleBatch,
        vf_loss_coeff: float,
        beta: float,
    ):
        # L = - A * log\pi_\theta(a|s)
        logprobs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        if beta != 0.0:
            cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
            # Advantage Estimation.
            adv = cumulative_rewards - value_estimates
            adv_squared = tf.reduce_mean(tf.math.square(adv))
            # Value function's loss term (MSE).
            self.v_loss = 0.5 * adv_squared

            # Perform moving averaging of advantage^2.
            rate = policy.config["moving_average_sqd_adv_norm_update_rate"]

            # Update the averaged advantage norm.
            # Eager.
            if policy.config["framework"] in ["tf2", "tfe"]:
                update_term = adv_squared - policy._moving_average_sqd_adv_norm
                policy._moving_average_sqd_adv_norm.assign_add(rate * update_term)

                # Exponentially weighted advantages.
                c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
                exp_advs = tf.math.exp(beta * (adv / (1e-8 + c)))
            # Static graph.
            else:
                update_adv_norm = tf1.assign_add(
                    ref=policy._moving_average_sqd_adv_norm,
                    value=rate * (adv_squared - policy._moving_average_sqd_adv_norm),
                )

                # Exponentially weighted advantages.
                with tf1.control_dependencies([update_adv_norm]):
                    exp_advs = tf.math.exp(
                        beta
                        * tf.math.divide(
                            adv,
                            1e-8 + tf.math.sqrt(policy._moving_average_sqd_adv_norm),
                        )
                    )
            exp_advs = tf.stop_gradient(exp_advs)

            self.explained_variance = tf.reduce_mean(
                explained_variance(cumulative_rewards, value_estimates)
            )

        else:
            # Value function's loss term (MSE).
            self.v_loss = tf.constant(0.0)
            exp_advs = 1.0

        # The logprob loss alone tends to push action distributions toward very
        # low entropy, which hurts performance in unfamiliar situations.
        # A scaled logstd loss term encourages stochasticity and thus
        # alleviates the problem to some extent.
        logstd_coeff = policy.config["bc_logstd_coeff"]
        if logstd_coeff > 0.0:
            logstds = tf.reduce_sum(action_dist.log_std, axis=1)
        else:
            logstds = 0.0

        self.p_loss = -1.0 * tf.reduce_mean(
            exp_advs * (logprobs + logstd_coeff * logstds)
        )

        self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
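
        # In summary, the pieces assembled above are (sketch):
        #   c^2      ~ moving average of E[(R - V(s))^2]  ("moving_average_sqd_adv_norm")
        #   exp_advs = exp(beta * (R - V(s)) / c)          (1.0 if beta == 0, i.e. plain BC)
        #   p_loss   = -mean(exp_advs * (log pi(a|s) + bc_logstd_coeff * logstd))
        #   v_loss   = 0.5 * mean((R - V(s))^2)            (0.0 if beta == 0)
        #   total    = p_loss + vf_loss_coeff * v_loss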


# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_marwil_tf_policy(base: type) -> type:
    """Construct a MARWILTFPolicy inheriting either dynamic or eager base policies.

    Args:
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with MARWIL.
    """

    class MARWILTFPolicy(ValueNetworkMixin, PostprocessAdvantages, base):
        def __init__(
            self,
            obs_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # First things first: enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            config = dict(
                ray.rllib.algorithms.marwil.marwil.MARWILConfig().to_dict(), **config
            )

            # Initialize the base class.
            base.__init__(
                self,
                obs_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            ValueNetworkMixin.__init__(self, config)
            PostprocessAdvantages.__init__(self)

            # Not needed for pure BC.
            if config["beta"] != 0.0:
                # Set up a tf-var for the moving avg (do this here to make it
                # work with eager mode); "c^2" in the paper.
                self._moving_average_sqd_adv_norm = get_variable(
                    config["moving_average_sqd_adv_norm_start"],
                    framework="tf",
                    tf_name="moving_average_of_advantage_norm",
                    trainable=False,
                )
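                # MARWILLoss scales the advantages by the square root of this
                # moving average and updates it on every loss pass.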

            # Note: this is a bit ugly, but loss and optimizer initialization must
            # happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

        @override(base)
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)
            value_estimates = model.value_function()

            self._marwil_loss = MARWILLoss(
                self,
                value_estimates,
                action_dist,
                train_batch,
                self.config["vf_coeff"],
                self.config["beta"],
            )

            return self._marwil_loss.total_loss

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            stats = {
                "policy_loss": self._marwil_loss.p_loss,
                "total_loss": self._marwil_loss.total_loss,
            }
            if self.config["beta"] != 0.0:
                stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
                stats["vf_explained_var"] = self._marwil_loss.explained_variance
                stats["vf_loss"] = self._marwil_loss.v_loss

            return stats

        @override(base)
        def compute_gradients_fn(
            self, optimizer: LocalOptimizer, loss: TensorType
        ) -> ModelGradients:
            return compute_gradients(self, optimizer, loss)

    return MARWILTFPolicy


MARWILTF1Policy = get_marwil_tf_policy(DynamicTFPolicyV2)
MARWILTF2Policy = get_marwil_tf_policy(EagerTFPolicyV2)
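

# Minimal usage sketch (assumes TensorFlow is installed and that you point
# `input_` at your own offline dataset; config calls follow the standard
# MARWILConfig builder API):
#
#   from ray.rllib.algorithms.marwil import MARWILConfig
#
#   config = (
#       MARWILConfig()
#       .environment("CartPole-v1")
#       .framework("tf2")
#       .offline_data(input_="/path/to/offline/json/data")
#   )
#   algo = config.build()
#   print(type(algo.get_policy()))  # -> MARWILTF2Policy (eager) for "tf2".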