From d5a6d46049d0ea0490c90366a081de79a87d0fac Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Fri, 20 May 2022 05:10:59 -0700 Subject: [PATCH] [RLlib] Migrate MAML, MB-MPO, MARWIL, and BC to use Policy sub-classing implementation. (#24914) --- rllib/agents/marwil/__init__.py | 8 +- rllib/agents/ppo/ppo_tf_policy.py | 4 +- rllib/agents/ppo/ppo_torch_policy.py | 4 +- rllib/algorithms/maml/maml.py | 12 +- rllib/algorithms/maml/maml_tf_policy.py | 274 ++++++++++------- rllib/algorithms/maml/maml_torch_policy.py | 262 +++++++++------- rllib/algorithms/marwil/__init__.py | 8 +- rllib/algorithms/marwil/marwil.py | 11 +- rllib/algorithms/marwil/marwil_tf_policy.py | 286 +++++++++--------- .../algorithms/marwil/marwil_torch_policy.py | 194 ++++++------ rllib/algorithms/marwil/tests/test_marwil.py | 6 +- rllib/algorithms/mbmpo/mbmpo.py | 3 +- rllib/algorithms/mbmpo/mbmpo_torch_policy.py | 181 +++++------ rllib/policy/policy.py | 34 ++- rllib/policy/tf_policy.py | 14 +- rllib/policy/torch_policy_v2.py | 44 +-- 16 files changed, 719 insertions(+), 626 deletions(-) diff --git a/rllib/agents/marwil/__init__.py b/rllib/agents/marwil/__init__.py index abd269487..0d87e76bd 100644 --- a/rllib/agents/marwil/__init__.py +++ b/rllib/agents/marwil/__init__.py @@ -1,13 +1,17 @@ from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG -from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy +from ray.rllib.algorithms.marwil.marwil_tf_policy import ( + MARWILDynamicTFPolicy, + MARWILEagerTFPolicy, +) from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy __all__ = [ "BCTrainer", "BC_DEFAULT_CONFIG", "DEFAULT_CONFIG", - "MARWILTFPolicy", + "MARWILDynamicTFPolicy", + "MARWILEagerTFPolicy", "MARWILTorchPolicy", "MARWILTrainer", ] diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index e2a70bd73..462fb5794 100644 --- 
a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -241,7 +241,7 @@ def compute_and_clip_gradients( return grads_and_vars -def setup_config( +def validate_config( policy: Policy, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, @@ -324,7 +324,7 @@ PPOTFPolicy = build_tf_policy( stats_fn=kl_and_loss_stats, compute_gradients_fn=compute_and_clip_gradients, extra_action_out_fn=vf_preds_fetches, - before_init=setup_config, + before_init=validate_config, before_loss_init=setup_mixins, mixins=[ LearningRateSchedule, diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index f5abc5db1..0e3a0d54c 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -2,7 +2,7 @@ import logging from typing import Dict, List, Type, Union import ray -from ray.rllib.agents.ppo.ppo_tf_policy import setup_config +from ray.rllib.agents.ppo.ppo_tf_policy import validate_config from ray.rllib.evaluation.postprocessing import ( compute_gae_for_sample_batch, Postprocessing, @@ -43,7 +43,7 @@ class PPOTorchPolicy( def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config) - setup_config(self, observation_space, action_space, config) + validate_config(self, observation_space, action_space, config) TorchPolicy.__init__( self, diff --git a/rllib/algorithms/maml/maml.py b/rllib/algorithms/maml/maml.py index a05284318..1c2e67b75 100644 --- a/rllib/algorithms/maml/maml.py +++ b/rllib/algorithms/maml/maml.py @@ -4,8 +4,6 @@ from typing import Type from ray.rllib.utils.sgd import standardized from ray.rllib.agents import with_common_config -from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy -from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy from ray.rllib.agents.trainer import Trainer from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.evaluation.worker_set import 
WorkerSet @@ -199,9 +197,17 @@ class MAMLTrainer(Trainer): @override(Trainer) def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: if config["framework"] == "torch": + from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy + return MAMLTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.algorithms.maml.maml_tf_policy import MAMLDynamicTFPolicy + + return MAMLDynamicTFPolicy else: - return MAMLTFPolicy + from ray.rllib.algorithms.maml.maml_tf_policy import MAMLEagerTFPolicy + + return MAMLEagerTFPolicy @staticmethod @override(Trainer) diff --git a/rllib/algorithms/maml/maml_tf_policy.py b/rllib/algorithms/maml/maml_tf_policy.py index c3d99dc4a..05b1da704 100644 --- a/rllib/algorithms/maml/maml_tf_policy.py +++ b/rllib/algorithms/maml/maml_tf_policy.py @@ -1,20 +1,23 @@ import logging +from typing import Dict, List, Type, Union import ray -from ray.rllib.agents.ppo.ppo_tf_policy import ( - vf_preds_fetches, - compute_and_clip_gradients, - setup_config, -) -from ray.rllib.evaluation.postprocessing import ( - compute_gae_for_sample_batch, - Postprocessing, -) +from ray.rllib.agents.ppo.ppo_tf_policy import validate_config +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.utils import get_activation_fn +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.policy.tf_mixins import ValueNetworkMixin +from ray.rllib.policy.tf_mixins import ( + ComputeAndClipGradsMixIn, + ComputeGAEMixIn, + ValueNetworkMixin, +) from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import TensorType tf1, tf, tfv 
= try_import_tf() @@ -326,72 +329,6 @@ class MAMLLoss(object): return placeholder_list -def maml_loss(policy, model, dist_class, train_batch): - logits, state = model(train_batch) - policy.cur_lr = policy.config["lr"] - - if policy.config["worker_index"]: - policy.loss_obj = WorkerLoss( - dist_class=dist_class, - actions=train_batch[SampleBatch.ACTIONS], - curr_logits=logits, - behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], - advantages=train_batch[Postprocessing.ADVANTAGES], - value_fn=model.value_function(), - value_targets=train_batch[Postprocessing.VALUE_TARGETS], - vf_preds=train_batch[SampleBatch.VF_PREDS], - cur_kl_coeff=0.0, - entropy_coeff=policy.config["entropy_coeff"], - clip_param=policy.config["clip_param"], - vf_clip_param=policy.config["vf_clip_param"], - vf_loss_coeff=policy.config["vf_loss_coeff"], - clip_loss=False, - ) - else: - policy.var_list = tf1.get_collection( - tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name - ) - policy.loss_obj = MAMLLoss( - model=model, - dist_class=dist_class, - value_targets=train_batch[Postprocessing.VALUE_TARGETS], - advantages=train_batch[Postprocessing.ADVANTAGES], - actions=train_batch[SampleBatch.ACTIONS], - behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], - vf_preds=train_batch[SampleBatch.VF_PREDS], - cur_kl_coeff=policy.kl_coeff, - policy_vars=policy.var_list, - obs=train_batch[SampleBatch.CUR_OBS], - num_tasks=policy.config["num_workers"], - split=train_batch["split"], - config=policy.config, - inner_adaptation_steps=policy.config["inner_adaptation_steps"], - entropy_coeff=policy.config["entropy_coeff"], - clip_param=policy.config["clip_param"], - vf_clip_param=policy.config["vf_clip_param"], - vf_loss_coeff=policy.config["vf_loss_coeff"], - use_gae=policy.config["use_gae"], - ) - - return policy.loss_obj.loss - - -def maml_stats(policy, train_batch): - if policy.config["worker_index"]: - return {"worker_loss": policy.loss_obj.loss} - else: - return { - 
"cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64), - "cur_lr": tf.cast(policy.cur_lr, tf.float64), - "total_loss": policy.loss_obj.loss, - "policy_loss": policy.loss_obj.mean_policy_loss, - "vf_loss": policy.loss_obj.mean_vf_loss, - "kl": policy.loss_obj.mean_kl, - "inner_kl": policy.loss_obj.mean_inner_kl, - "entropy": policy.loss_obj.mean_entropy, - } - - class KLCoeffMixin: def __init__(self, config): self.kl_coeff_val = [config["kl_coeff"]] * config["inner_adaptation_steps"] @@ -415,41 +352,154 @@ class KLCoeffMixin: return self.kl_coeff_val -def maml_optimizer_fn(policy, config): +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_maml_tf_policy(base: type) -> type: + """Construct a MAMLTFPolicy inheriting either dynamic or eager base policies. + + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. + + Returns: + A TF Policy to be used with MAMLTrainer. """ - Workers use simple SGD for inner adaptation - Meta-Policy uses Adam optimizer for meta-update - """ - if not config["worker_index"]: - return tf1.train.AdamOptimizer(learning_rate=config["lr"]) - return tf1.train.GradientDescentOptimizer(learning_rate=config["inner_lr"]) + + class MAMLTFPolicy( + ComputeGAEMixIn, ComputeAndClipGradsMixIn, KLCoeffMixin, ValueNetworkMixin, base + ): + def __init__( + self, + obs_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config) + validate_config(self, obs_space, action_space, config) + + # Initialize base class. 
+ base.__init__( + self, + obs_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + ComputeGAEMixIn.__init__(self) + ComputeAndClipGradsMixIn.__init__(self) + KLCoeffMixin.__init__(self, config) + ValueNetworkMixin.__init__(self, config) + + # Create the `split` placeholder before initialize loss. + if self.framework == "tf": + self._loss_input_dict["split"] = tf1.placeholder( + tf.int32, + name="Meta-Update-Splitting", + shape=( + self.config["inner_adaptation_steps"] + 1, + self.config["num_workers"], + ), + ) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + logits, state = model(train_batch) + self.cur_lr = self.config["lr"] + + if self.config["worker_index"]: + self.loss_obj = WorkerLoss( + dist_class=dist_class, + actions=train_batch[SampleBatch.ACTIONS], + curr_logits=logits, + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + advantages=train_batch[Postprocessing.ADVANTAGES], + value_fn=model.value_function(), + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=0.0, + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + clip_loss=False, + ) + else: + self.var_list = tf1.get_collection( + tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name + ) + self.loss_obj = MAMLLoss( + model=model, + dist_class=dist_class, + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + advantages=train_batch[Postprocessing.ADVANTAGES], + actions=train_batch[SampleBatch.ACTIONS], + 
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=self.kl_coeff, + policy_vars=self.var_list, + obs=train_batch[SampleBatch.CUR_OBS], + num_tasks=self.config["num_workers"], + split=train_batch["split"], + config=self.config, + inner_adaptation_steps=self.config["inner_adaptation_steps"], + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + use_gae=self.config["use_gae"], + ) + + return self.loss_obj.loss + + @override(base) + def optimizer( + self, + ) -> Union[ + "tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"] + ]: + """ + Workers use simple SGD for inner adaptation + Meta-Policy uses Adam optimizer for meta-update + """ + if not self.config["worker_index"]: + return tf1.train.AdamOptimizer(learning_rate=self.config["lr"]) + return tf1.train.GradientDescentOptimizer( + learning_rate=self.config["inner_lr"] + ) + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + if self.config["worker_index"]: + return {"worker_loss": self.loss_obj.loss} + else: + return { + "cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64), + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "total_loss": self.loss_obj.loss, + "policy_loss": self.loss_obj.mean_policy_loss, + "vf_loss": self.loss_obj.mean_vf_loss, + "kl": self.loss_obj.mean_kl, + "inner_kl": self.loss_obj.mean_inner_kl, + "entropy": self.loss_obj.mean_entropy, + } + + return MAMLTFPolicy -def setup_mixins(policy, obs_space, action_space, config): - ValueNetworkMixin.__init__(policy, config) - KLCoeffMixin.__init__(policy, config) - - # Create the `split` placeholder. 
- policy._loss_input_dict["split"] = tf1.placeholder( - tf.int32, - name="Meta-Update-Splitting", - shape=( - policy.config["inner_adaptation_steps"] + 1, - policy.config["num_workers"], - ), - ) - - -MAMLTFPolicy = build_tf_policy( - name="MAMLTFPolicy", - get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, - loss_fn=maml_loss, - stats_fn=maml_stats, - optimizer_fn=maml_optimizer_fn, - extra_action_out_fn=vf_preds_fetches, - postprocess_fn=compute_gae_for_sample_batch, - compute_gradients_fn=compute_and_clip_gradients, - before_init=setup_config, - before_loss_init=setup_mixins, - mixins=[KLCoeffMixin], -) +MAMLDynamicTFPolicy = get_maml_tf_policy(DynamicTFPolicyV2) +MAMLEagerTFPolicy = get_maml_tf_policy(EagerTFPolicyV2) diff --git a/rllib/algorithms/maml/maml_torch_policy.py b/rllib/algorithms/maml/maml_torch_policy.py index 8c4ae55a7..73bef4e90 100644 --- a/rllib/algorithms/maml/maml_torch_policy.py +++ b/rllib/algorithms/maml/maml_torch_policy.py @@ -1,17 +1,20 @@ +import higher import logging +from typing import Dict, List, Type, Union import ray -from ray.rllib.evaluation.postprocessing import ( - compute_gae_for_sample_batch, - Postprocessing, -) -from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.agents.ppo.ppo_tf_policy import validate_config +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_mixins import ValueNetworkMixin -from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches -from ray.rllib.agents.ppo.ppo_tf_policy import setup_config +from ray.rllib.policy.torch_mixins import ComputeGAEMixIn, ValueNetworkMixin +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.numpy import convert_to_numpy 
from ray.rllib.utils.torch_utils import apply_grad_clipping from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() @@ -145,9 +148,6 @@ class MAMLLoss(object): vf_loss_coeff=1.0, use_gae=True, ): - - import higher - self.config = config self.num_tasks = num_tasks self.inner_adaptation_steps = inner_adaptation_steps @@ -266,86 +266,6 @@ class MAMLLoss(object): return placeholder_list -def maml_loss(policy, model, dist_class, train_batch): - logits, state = model(train_batch) - policy.cur_lr = policy.config["lr"] - - if policy.config["worker_index"]: - policy.loss_obj = WorkerLoss( - model=model, - dist_class=dist_class, - actions=train_batch[SampleBatch.ACTIONS], - curr_logits=logits, - behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], - advantages=train_batch[Postprocessing.ADVANTAGES], - value_fn=model.value_function(), - value_targets=train_batch[Postprocessing.VALUE_TARGETS], - vf_preds=train_batch[SampleBatch.VF_PREDS], - cur_kl_coeff=0.0, - entropy_coeff=policy.config["entropy_coeff"], - clip_param=policy.config["clip_param"], - vf_clip_param=policy.config["vf_clip_param"], - vf_loss_coeff=policy.config["vf_loss_coeff"], - clip_loss=False, - ) - else: - policy.var_list = model.named_parameters() - - # `split` may not exist yet (during test-loss call), use a dummy value. - # Cannot use get here due to train_batch being a TrackingDict. 
- if "split" in train_batch: - split = train_batch["split"] - else: - split_shape = ( - policy.config["inner_adaptation_steps"], - policy.config["num_workers"], - ) - split_const = int( - train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1]) - ) - split = torch.ones(split_shape, dtype=int) * split_const - policy.loss_obj = MAMLLoss( - model=model, - dist_class=dist_class, - value_targets=train_batch[Postprocessing.VALUE_TARGETS], - advantages=train_batch[Postprocessing.ADVANTAGES], - actions=train_batch[SampleBatch.ACTIONS], - behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], - vf_preds=train_batch[SampleBatch.VF_PREDS], - cur_kl_coeff=policy.kl_coeff_val, - policy_vars=policy.var_list, - obs=train_batch[SampleBatch.CUR_OBS], - num_tasks=policy.config["num_workers"], - split=split, - config=policy.config, - inner_adaptation_steps=policy.config["inner_adaptation_steps"], - entropy_coeff=policy.config["entropy_coeff"], - clip_param=policy.config["clip_param"], - vf_clip_param=policy.config["vf_clip_param"], - vf_loss_coeff=policy.config["vf_loss_coeff"], - use_gae=policy.config["use_gae"], - meta_opt=policy.meta_opt, - ) - - return policy.loss_obj.loss - - -def maml_stats(policy, train_batch): - if policy.config["worker_index"]: - return {"worker_loss": policy.loss_obj.loss} - else: - return { - "cur_kl_coeff": policy.kl_coeff_val, - "cur_lr": policy.cur_lr, - "total_loss": policy.loss_obj.loss, - "policy_loss": policy.loss_obj.mean_policy_loss, - "vf_loss": policy.loss_obj.mean_vf_loss, - "kl_loss": policy.loss_obj.mean_kl_loss, - "inner_kl": policy.loss_obj.mean_inner_kl, - "entropy": policy.loss_obj.mean_entropy, - } - - class KLCoeffMixin: def __init__(self, config): self.kl_coeff_val = ( @@ -364,33 +284,141 @@ class KLCoeffMixin: return self.kl_coeff_val -def maml_optimizer_fn(policy, config): - """ - Workers use simple SGD for inner adaptation - Meta-Policy uses Adam optimizer for meta-update - """ - if not config["worker_index"]: - 
policy.meta_opt = torch.optim.Adam(policy.model.parameters(), lr=config["lr"]) - return policy.meta_opt - return torch.optim.SGD(policy.model.parameters(), lr=config["inner_lr"]) +class MAMLTorchPolicy(ComputeGAEMixIn, ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2): + """PyTorch policy class used with MAMLTrainer.""" + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config) + validate_config(self, observation_space, action_space, config) -def setup_mixins(policy, obs_space, action_space, config): - ValueNetworkMixin.__init__(policy, config) - KLCoeffMixin.__init__(policy, config) + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) + ComputeGAEMixIn.__init__(self) + KLCoeffMixin.__init__(self, config) + ValueNetworkMixin.__init__(self, config) -MAMLTorchPolicy = build_policy_class( - name="MAMLTorchPolicy", - framework="torch", - get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, - loss_fn=maml_loss, - stats_fn=maml_stats, - optimizer_fn=maml_optimizer_fn, - extra_action_out_fn=vf_preds_fetches, - postprocess_fn=compute_gae_for_sample_batch, - extra_grad_process_fn=apply_grad_clipping, - before_init=setup_config, - after_init=setup_mixins, - mixins=[KLCoeffMixin], -) + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() + + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Constructs the loss function. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. + + Returns: + The MAML loss tensor given the input batch. 
+ """ + logits, state = model(train_batch) + self.cur_lr = self.config["lr"] + + if self.config["worker_index"]: + self.loss_obj = WorkerLoss( + model=model, + dist_class=dist_class, + actions=train_batch[SampleBatch.ACTIONS], + curr_logits=logits, + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + advantages=train_batch[Postprocessing.ADVANTAGES], + value_fn=model.value_function(), + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=0.0, + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + clip_loss=False, + ) + else: + self.var_list = model.named_parameters() + + # `split` may not exist yet (during test-loss call), use a dummy value. + # Cannot use get here due to train_batch being a TrackingDict. + if "split" in train_batch: + split = train_batch["split"] + else: + split_shape = ( + self.config["inner_adaptation_steps"], + self.config["num_workers"], + ) + split_const = int( + train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1]) + ) + split = torch.ones(split_shape, dtype=int) * split_const + self.loss_obj = MAMLLoss( + model=model, + dist_class=dist_class, + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + advantages=train_batch[Postprocessing.ADVANTAGES], + actions=train_batch[SampleBatch.ACTIONS], + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=self.kl_coeff_val, + policy_vars=self.var_list, + obs=train_batch[SampleBatch.CUR_OBS], + num_tasks=self.config["num_workers"], + split=split, + config=self.config, + inner_adaptation_steps=self.config["inner_adaptation_steps"], + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + 
use_gae=self.config["use_gae"], + meta_opt=self.meta_opt, + ) + + return self.loss_obj.loss + + @override(TorchPolicyV2) + def optimizer( + self, + ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + """ + Workers use simple SGD for inner adaptation + Meta-Policy uses Adam optimizer for meta-update + """ + if not self.config["worker_index"]: + self.meta_opt = torch.optim.Adam( + self.model.parameters(), lr=self.config["lr"] + ) + return self.meta_opt + return torch.optim.SGD(self.model.parameters(), lr=self.config["inner_lr"]) + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + if self.config["worker_index"]: + return convert_to_numpy({"worker_loss": self.loss_obj.loss}) + else: + return convert_to_numpy( + { + "cur_kl_coeff": self.kl_coeff_val, + "cur_lr": self.cur_lr, + "total_loss": self.loss_obj.loss, + "policy_loss": self.loss_obj.mean_policy_loss, + "vf_loss": self.loss_obj.mean_vf_loss, + "kl_loss": self.loss_obj.mean_kl_loss, + "inner_kl": self.loss_obj.mean_inner_kl, + "entropy": self.loss_obj.mean_entropy, + } + ) + + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) diff --git a/rllib/algorithms/marwil/__init__.py b/rllib/algorithms/marwil/__init__.py index 930e248e7..aa01ec98f 100644 --- a/rllib/algorithms/marwil/__init__.py +++ b/rllib/algorithms/marwil/__init__.py @@ -1,13 +1,17 @@ from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG -from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy +from ray.rllib.algorithms.marwil.marwil_tf_policy import ( + MARWILDynamicTFPolicy, + MARWILEagerTFPolicy, +) from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy __all__ = [ "BCTrainer", "BC_DEFAULT_CONFIG", "DEFAULT_CONFIG", - "MARWILTFPolicy", 
+ "MARWILDynamicTFPolicy", + "MARWILEagerTFPolicy", "MARWILTorchPolicy", "MARWILTrainer", ] diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py index 521a3c5e3..34cca8d86 100644 --- a/rllib/algorithms/marwil/marwil.py +++ b/rllib/algorithms/marwil/marwil.py @@ -1,7 +1,6 @@ from typing import Type from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy from ray.rllib.utils.replay_buffers.utils import validate_buffer_config from ray.rllib.execution.rollout_ops import ( synchronous_parallel_sample, @@ -123,8 +122,16 @@ class MARWILTrainer(Trainer): ) return MARWILTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.algorithms.marwil.marwil_tf_policy import ( + MARWILDynamicTFPolicy, + ) + + return MARWILDynamicTFPolicy else: - return MARWILTFPolicy + from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy + + return MARWILEagerTFPolicy @override(Trainer) def training_iteration(self) -> ResultDict: diff --git a/rllib/algorithms/marwil/marwil_tf_policy.py b/rllib/algorithms/marwil/marwil_tf_policy.py index 17417c715..ada8ba66b 100644 --- a/rllib/algorithms/marwil/marwil_tf_policy.py +++ b/rllib/algorithms/marwil/marwil_tf_policy.py @@ -1,100 +1,69 @@ import logging -import gym -from typing import Optional, Dict +from typing import Any, Dict, List, Optional, Type, Union import ray -from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing -from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.framework import try_import_tf, get_variable -from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable -from ray.rllib.policy.policy import Policy -from ray.rllib.utils.typing import 
TrainerConfigDict, TensorType, PolicyID from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import ComputeAndClipGradsMixIn, ValueNetworkMixin +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf, get_variable +from ray.rllib.utils.tf_utils import explained_variance +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) -class ValueNetworkMixin: - def __init__( +class PostprocessAdvantages: + """Marwil's custom trajectory post-processing mixin.""" + + def __init__(self): + pass + + def postprocess_trajectory( self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, + episode: Optional["Episode"] = None, ): - - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - @make_tf_callable(self.get_session()) - def value(**input_dict): - model_out, _ = self.model(input_dict) - # [0] = remove the batch dim. - return self.model.value_function()[0] - - self._value = value - - -def postprocess_advantages( - policy: Policy, - sample_batch: SampleBatch, - other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None, - episode=None, -) -> SampleBatch: - """Postprocesses a trajectory and returns the processed trajectory. - - The trajectory contains only data from one episode and from one agent. 
- - If `config.batch_mode=truncate_episodes` (default), sample_batch may - contain a truncated (at-the-end) episode, in case the - `config.rollout_fragment_length` was reached by the sampler. - - If `config.batch_mode=complete_episodes`, sample_batch will contain - exactly one episode (no matter how long). - New columns can be added to sample_batch and existing ones may be altered. - - Args: - policy (Policy): The Policy used to generate the trajectory - (`sample_batch`) - sample_batch (SampleBatch): The SampleBatch to postprocess. - other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional - dict of AgentIDs mapping to other agents' trajectory data (from the - same episode). NOTE: The other agents use the same policy. - episode (Optional[Episode]): Optional multi-agent episode - object in which the agents operated. - - Returns: - SampleBatch: The postprocessed, modified SampleBatch (or a new one). - """ - - # Trajectory is actually complete -> last r=0.0. - if sample_batch[SampleBatch.DONES][-1]: - last_r = 0.0 - # Trajectory has been truncated -> last r=VF estimate of last obs. - else: - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - # Create an input dict according to the Model's requirements. - index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1 - input_dict = sample_batch.get_single_step_input_dict( - policy.model.view_requirements, index=index + sample_batch = super().postprocess_trajectory( + sample_batch, other_agent_batches, episode ) - last_r = policy._value(**input_dict) - # Adds the "advantages" (which in the case of MARWIL are simply the - # discounted cummulative rewards) to the SampleBatch. - return compute_advantages( - sample_batch, - last_r, - policy.config["gamma"], - # We just want the discounted cummulative rewards, so we won't need - # GAE nor critic (use_critic=True: Subtract vf-estimates from returns). 
- use_gae=False, - use_critic=False, - ) + # Trajectory is actually complete -> last r=0.0. + if sample_batch[SampleBatch.DONES][-1]: + last_r = 0.0 + # Trajectory has been truncated -> last r=VF estimate of last obs. + else: + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + # Create an input dict according to the Model's requirements. + index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1 + input_dict = sample_batch.get_single_step_input_dict( + self.model.view_requirements, index=index + ) + last_r = self._value(**input_dict) + + # Adds the "advantages" (which in the case of MARWIL are simply the + # discounted cumulative rewards) to the SampleBatch. + return compute_advantages( + sample_batch, + last_r, + self.config["gamma"], + # We just want the discounted cumulative rewards, so we won't need + # GAE nor critic (use_critic=True: Subtract vf-estimates from returns). + use_gae=False, + use_critic=False, + ) class MARWILLoss: @@ -174,69 +143,100 @@ class MARWILLoss: self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss -def marwil_loss( - policy: Policy, - model: ModelV2, - dist_class: ActionDistribution, - train_batch: SampleBatch, -) -> TensorType: - model_out, _ = model(train_batch) - action_dist = dist_class(model_out, model) - value_estimates = model.value_function() +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_marwil_tf_policy(base: type) -> type: + """Construct a MARWILTFPolicy inheriting either dynamic or eager base policies. - policy.loss = MARWILLoss( - policy, - value_estimates, - action_dist, - train_batch, - policy.config["vf_coeff"], - policy.config["beta"], - ) + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. - return policy.loss.total_loss + Returns: + A TF Policy to be used with MARWILTrainer. 
+ """ + + class MARWILTFPolicy( + ComputeAndClipGradsMixIn, ValueNetworkMixin, PostprocessAdvantages, base + ): + def __init__( + self, + obs_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config) + + # Initialize base class. + base.__init__( + self, + obs_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + ComputeAndClipGradsMixIn.__init__(self) + ValueNetworkMixin.__init__(self, config) + PostprocessAdvantages.__init__(self) + + # Not needed for pure BC. + if config["beta"] != 0.0: + # Set up a tf-var for the moving avg (do this here to make it work + # with eager mode); "c^2" in the paper. + self._moving_average_sqd_adv_norm = get_variable( + config["moving_average_sqd_adv_norm_start"], + framework="tf", + tf_name="moving_average_of_advantage_norm", + trainable=False, + ) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. 
+ self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + value_estimates = model.value_function() + + self.loss = MARWILLoss( + self, + value_estimates, + action_dist, + train_batch, + self.config["vf_coeff"], + self.config["beta"], + ) + + return self.loss.total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + stats = { + "policy_loss": self.loss.p_loss, + "total_loss": self.loss.total_loss, + } + if self.config["beta"] != 0.0: + stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm + stats["vf_explained_var"] = self.loss.explained_variance + stats["vf_loss"] = self.loss.v_loss + + return stats + + return MARWILTFPolicy -def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - stats = { - "policy_loss": policy.loss.p_loss, - "total_loss": policy.loss.total_loss, - } - if policy.config["beta"] != 0.0: - stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm - stats["vf_explained_var"] = policy.loss.explained_variance - stats["vf_loss"] = policy.loss.v_loss - - return stats - - -def setup_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: - # Setup Value branch of our NN. - ValueNetworkMixin.__init__(policy, obs_space, action_space, config) - - # Not needed for pure BC. - if policy.config["beta"] != 0.0: - # Set up a tf-var for the moving avg (do this here to make it work - # with eager mode); "c^2" in the paper. 
- policy._moving_average_sqd_adv_norm = get_variable( - policy.config["moving_average_sqd_adv_norm_start"], - framework="tf", - tf_name="moving_average_of_advantage_norm", - trainable=False, - ) - - -MARWILTFPolicy = build_tf_policy( - name="MARWILTFPolicy", - get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, - loss_fn=marwil_loss, - stats_fn=stats, - postprocess_fn=postprocess_advantages, - before_loss_init=setup_mixins, - compute_gradients_fn=compute_and_clip_gradients, - mixins=[ValueNetworkMixin], -) +MARWILDynamicTFPolicy = get_marwil_tf_policy(DynamicTFPolicyV2) +MARWILEagerTFPolicy = get_marwil_tf_policy(EagerTFPolicyV2) diff --git a/rllib/algorithms/marwil/marwil_torch_policy.py b/rllib/algorithms/marwil/marwil_torch_policy.py index 7c4b20bb7..10734fc8c 100644 --- a/rllib/algorithms/marwil/marwil_torch_policy.py +++ b/rllib/algorithms/marwil/marwil_torch_policy.py @@ -1,122 +1,124 @@ -import gym -from typing import Dict +from typing import Dict, List, Type, Union import ray -from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages +from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages from ray.rllib.evaluation.postprocessing import Postprocessing -from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_mixins import ValueNetworkMixin +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance -from ray.rllib.utils.typing import TrainerConfigDict, TensorType -from ray.rllib.policy.policy import Policy -from ray.rllib.models.action_dist 
import ActionDistribution -from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.typing import TensorType torch, _ = try_import_torch() -def marwil_loss( - policy: Policy, - model: ModelV2, - dist_class: ActionDistribution, - train_batch: SampleBatch, -) -> TensorType: - model_out, _ = model(train_batch) - action_dist = dist_class(model_out, model) - actions = train_batch[SampleBatch.ACTIONS] - # log\pi_\theta(a|s) - logprobs = action_dist.logp(actions) +class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2): + """PyTorch policy class used with MarwilTrainer.""" - # Advantage estimation. - if policy.config["beta"] != 0.0: - cumulative_rewards = train_batch[Postprocessing.ADVANTAGES] - state_values = model.value_function() - adv = cumulative_rewards - state_values - adv_squared_mean = torch.mean(torch.pow(adv, 2.0)) + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config) - explained_var = explained_variance(cumulative_rewards, state_values) - policy.explained_variance = torch.mean(explained_var) - - # Policy loss. - # Update averaged advantage norm. - rate = policy.config["moving_average_sqd_adv_norm_update_rate"] - policy._moving_average_sqd_adv_norm.add_( - rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm) + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], ) - # Exponentially weighted advantages. - exp_advs = torch.exp( - policy.config["beta"] - * (adv / (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5))) - ).detach() - # Value loss. - policy.v_loss = 0.5 * adv_squared_mean - else: - # Policy loss (simple BC loss term). - exp_advs = 1.0 - # Value loss. - policy.v_loss = 0.0 - # logprob loss alone tends to push action distributions to - # have very low entropy, resulting in worse performance for - # unfamiliar situations. 
- # A scaled logstd loss term encourages stochasticity, thus - # alleviate the problem to some extent. - logstd_coeff = policy.config["bc_logstd_coeff"] - if logstd_coeff > 0.0: - logstds = torch.mean(action_dist.log_std, dim=1) - else: - logstds = 0.0 + ValueNetworkMixin.__init__(self, config) + PostprocessAdvantages.__init__(self) - policy.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds)) + # Not needed for pure BC. + if config["beta"] != 0.0: + # Set up a torch-var for the squared moving avg. advantage norm. + self._moving_average_sqd_adv_norm = torch.tensor( + [config["moving_average_sqd_adv_norm_start"]], + dtype=torch.float32, + requires_grad=False, + ).to(self.device) - # Combine both losses. - policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * policy.v_loss + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() - return policy.total_loss + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + actions = train_batch[SampleBatch.ACTIONS] + # log\pi_\theta(a|s) + logprobs = action_dist.logp(actions) + # Advantage estimation. 
+ if self.config["beta"] != 0.0: + cumulative_rewards = train_batch[Postprocessing.ADVANTAGES] + state_values = model.value_function() + adv = cumulative_rewards - state_values + adv_squared_mean = torch.mean(torch.pow(adv, 2.0)) -def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: - stats = { - "policy_loss": policy.p_loss, - "total_loss": policy.total_loss, - } - if policy.config["beta"] != 0.0: - stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm - stats["vf_explained_var"] = policy.explained_variance - stats["vf_loss"] = policy.v_loss + explained_var = explained_variance(cumulative_rewards, state_values) + self.explained_variance = torch.mean(explained_var) - return stats + # Policy loss. + # Update averaged advantage norm. + rate = self.config["moving_average_sqd_adv_norm_update_rate"] + self._moving_average_sqd_adv_norm.add_( + rate * (adv_squared_mean - self._moving_average_sqd_adv_norm) + ) + # Exponentially weighted advantages. + exp_advs = torch.exp( + self.config["beta"] + * (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5))) + ).detach() + # Value loss. + self.v_loss = 0.5 * adv_squared_mean + else: + # Policy loss (simple BC loss term). + exp_advs = 1.0 + # Value loss. + self.v_loss = 0.0 + # logprob loss alone tends to push action distributions to + # have very low entropy, resulting in worse performance for + # unfamiliar situations. + # A scaled logstd loss term encourages stochasticity, thus + # alleviate the problem to some extent. + logstd_coeff = self.config["bc_logstd_coeff"] + if logstd_coeff > 0.0: + logstds = torch.mean(action_dist.log_std, dim=1) + else: + logstds = 0.0 -def setup_mixins( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: - # Setup Value branch of our NN. 
- ValueNetworkMixin.__init__(policy, config) + self.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds)) - # Not needed for pure BC. - if policy.config["beta"] != 0.0: - # Set up a torch-var for the squared moving avg. advantage norm. - policy._moving_average_sqd_adv_norm = torch.tensor( - [policy.config["moving_average_sqd_adv_norm_start"]], - dtype=torch.float32, - requires_grad=False, - ).to(policy.device) + # Combine both losses. + self.total_loss = self.p_loss + self.config["vf_coeff"] * self.v_loss + return self.total_loss -MARWILTorchPolicy = build_policy_class( - name="MARWILTorchPolicy", - framework="torch", - loss_fn=marwil_loss, - get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, - stats_fn=stats, - postprocess_fn=postprocess_advantages, - extra_grad_process_fn=apply_grad_clipping, - before_loss_init=setup_mixins, - mixins=[ValueNetworkMixin], -) + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + stats = { + "policy_loss": self.p_loss, + "total_loss": self.total_loss, + } + if self.config["beta"] != 0.0: + stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm + stats["vf_explained_var"] = self.explained_variance + stats["vf_loss"] = self.v_loss + return convert_to_numpy(stats) + + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py index d4b721439..cc19194ed 100644 --- a/rllib/algorithms/marwil/tests/test_marwil.py +++ b/rllib/algorithms/marwil/tests/test_marwil.py @@ -5,6 +5,8 @@ import unittest import ray import ray.rllib.algorithms.marwil as marwil +from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy +from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy from 
ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.offline import JsonReader from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -182,9 +184,7 @@ class TestMARWIL(unittest.TestCase): batch.set_get_interceptor(None) postprocessed_batch = policy.postprocess_trajectory(batch) loss_func = ( - marwil.marwil_tf_policy.marwil_loss - if fw != "torch" - else marwil.marwil_torch_policy.marwil_loss + MARWILEagerTFPolicy.loss if fw != "torch" else MARWILTorchPolicy.loss ) if fw != "tf": policy._lazy_tensor_dict(postprocessed_batch) diff --git a/rllib/algorithms/mbmpo/mbmpo.py b/rllib/algorithms/mbmpo/mbmpo.py index d2e582c17..f2a5e2607 100644 --- a/rllib/algorithms/mbmpo/mbmpo.py +++ b/rllib/algorithms/mbmpo/mbmpo.py @@ -4,7 +4,6 @@ from typing import List, Type import ray from ray.rllib.agents import with_common_config -from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration from ray.rllib.agents.trainer import Trainer @@ -374,6 +373,8 @@ class MBMPOTrainer(Trainer): @override(Trainer) def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: + from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy + return MBMPOTorchPolicy @staticmethod diff --git a/rllib/algorithms/mbmpo/mbmpo_torch_policy.py b/rllib/algorithms/mbmpo/mbmpo_torch_policy.py index 46a4e9e8c..b7f950324 100644 --- a/rllib/algorithms/mbmpo/mbmpo_torch_policy.py +++ b/rllib/algorithms/mbmpo/mbmpo_torch_policy.py @@ -1,135 +1,90 @@ -import gym from gym.spaces import Box, Discrete import logging from typing import Tuple, Type import ray -from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches -from ray.rllib.algorithms.maml.maml_torch_policy import ( - setup_mixins, - maml_loss, - maml_stats, - maml_optimizer_fn, - KLCoeffMixin, -) 
-from ray.rllib.agents.ppo.ppo_tf_policy import setup_config -from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch +from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_utils import apply_grad_clipping -from ray.rllib.utils.typing import TrainerConfigDict torch, nn = try_import_torch() logger = logging.getLogger(__name__) -def validate_spaces( - policy: Policy, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> None: - """Validates the observation- and action spaces used for the Policy. +class MBMPOTorchPolicy(MAMLTorchPolicy): + def __init__(self, observation_space, action_space, config): + # Validate spaces. + # Only support single Box or single Discrete spaces. + if not isinstance(action_space, (Box, Discrete)): + raise UnsupportedSpaceException( + "Action space ({}) of {} is not supported for " + "MB-MPO. Must be [Box|Discrete].".format(action_space, self) + ) + # If Box, make sure it's a 1D vector space. + elif isinstance(action_space, Box) and len(action_space.shape) > 1: + raise UnsupportedSpaceException( + "Action space ({}) of {} has multiple dimensions " + "{}. ".format(action_space, self, action_space.shape) + + "Consider reshaping this into a single dimension Box space " + "or using the multi-agent API." + ) - Args: - policy (Policy): The policy, whose spaces are being validated. - observation_space (gym.spaces.Space): The observation space to - validate. - action_space (gym.spaces.Space): The action space to validate. 
- config (TrainerConfigDict): The Policy's config dict. + config = dict(ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG, **config) + super().__init__(observation_space, action_space, config) - Raises: - UnsupportedSpaceException: If one of the spaces is not supported. - """ - # Only support single Box or single Discrete spaces. - if not isinstance(action_space, (Box, Discrete)): - raise UnsupportedSpaceException( - "Action space ({}) of {} is not supported for " - "MB-MPO. Must be [Box|Discrete].".format(action_space, policy) - ) - # If Box, make sure it's a 1D vector space. - elif isinstance(action_space, Box) and len(action_space.shape) > 1: - raise UnsupportedSpaceException( - "Action space ({}) of {} has multiple dimensions " - "{}. ".format(action_space, policy, action_space.shape) - + "Consider reshaping this into a single dimension Box space " - "or using the multi-agent API." + def make_model_and_action_dist( + self, + ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]: + """Constructs the necessary ModelV2 and action dist class for the Policy. + + Args: + obs_space (gym.spaces.Space): The observation space. + action_space (gym.spaces.Space): The action space. + config (TrainerConfigDict): The SAC trainer's config dict. + + Returns: + ModelV2: The ModelV2 to be used by the Policy. Note: An additional + target model will be created in this function and assigned to + `policy.target_model`. + """ + # Get the output distribution class for predicting rewards and next-obs. + self.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist( + self.observation_space, + self.config, + dist_type="deterministic", + framework="torch", ) + # Build one dynamics model if we are a Worker. + # If we are the main MAML learner, build n (num_workers) dynamics Models + # for being able to create checkpoints for the current state of training. 
+ device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + self.dynamics_model = ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + num_outputs=num_outputs, + model_config=self.config["dynamics_model"], + framework="torch", + name="dynamics_ensemble", + ).to(device) -def make_model_and_action_dist( - policy: Policy, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict, -) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]: - """Constructs the necessary ModelV2 and action dist class for the Policy. + action_dist, num_outputs = ModelCatalog.get_action_dist( + self.action_space, self.config, framework="torch" + ) + # Create the pi-model and register it with the Policy. + self.pi = ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + num_outputs=num_outputs, + model_config=self.config["model"], + framework="torch", + name="policy_model", + ) - Args: - policy (Policy): The TFPolicy that will use the models. - obs_space (gym.spaces.Space): The observation space. - action_space (gym.spaces.Space): The action space. - config (TrainerConfigDict): The SAC trainer's config dict. - - Returns: - ModelV2: The ModelV2 to be used by the Policy. Note: An additional - target model will be created in this function and assigned to - `policy.target_model`. - """ - # Get the output distribution class for predicting rewards and next-obs. - policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist( - obs_space, config, dist_type="deterministic", framework="torch" - ) - - # Build one dynamics model if we are a Worker. - # If we are the main MAML learner, build n (num_workers) dynamics Models - # for being able to create checkpoints for the current state of training. 
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - policy.dynamics_model = ModelCatalog.get_model_v2( - obs_space, - action_space, - num_outputs=num_outputs, - model_config=config["dynamics_model"], - framework="torch", - name="dynamics_ensemble", - ).to(device) - - action_dist, num_outputs = ModelCatalog.get_action_dist( - action_space, config, framework="torch" - ) - # Create the pi-model and register it with the Policy. - policy.pi = ModelCatalog.get_model_v2( - obs_space, - action_space, - num_outputs=num_outputs, - model_config=config["model"], - framework="torch", - name="policy_model", - ) - - return policy.pi, action_dist - - -# Build a child class of `TorchPolicy`, given the custom functions defined -# above. -MBMPOTorchPolicy = build_policy_class( - name="MBMPOTorchPolicy", - framework="torch", - get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG, - make_model_and_action_dist=make_model_and_action_dist, - loss_fn=maml_loss, - stats_fn=maml_stats, - optimizer_fn=maml_optimizer_fn, - extra_action_out_fn=vf_preds_fetches, - postprocess_fn=compute_gae_for_sample_batch, - extra_grad_process_fn=apply_grad_clipping, - before_init=setup_config, - after_init=setup_mixins, - mixins=[KLCoeffMixin], -) + return self.pi, action_dist diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index d5b6d3317..f2f5db9d2 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -29,6 +29,8 @@ from ray.rllib.utils.annotations import ( DeveloperAPI, ExperimentalAPI, OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, + is_overridden, ) from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.exploration.exploration import Exploration @@ -155,6 +157,20 @@ class Policy(metaclass=ABCMeta): # Child classes may set this. self.dist_class: Optional[Type] = None + # Initialize view requirements. 
+ self.init_view_requirements() + + # Whether the Model's initial state (method) has been added + # automatically based on the given view requirements of the model. + self._model_init_state_automatically_added = False + + @DeveloperAPI + def init_view_requirements(self): + """Maximal view requirements dict for `learn_on_batch()` and + `compute_actions` calls. + Specific policies can override this function to provide custom + list of view requirements. + """ # Maximal view requirements dict for `learn_on_batch()` and # `compute_actions` calls. # View requirements will be automatically filtered out later based @@ -167,9 +183,6 @@ class Policy(metaclass=ABCMeta): for k, v in view_reqs.items(): if k not in self.view_requirements: self.view_requirements[k] = v - # Whether the Model's initial state (method) has been added - # automatically based on the given view requirements of the model. - self._model_init_state_automatically_added = False @DeveloperAPI def compute_single_action( @@ -413,6 +426,7 @@ class Policy(metaclass=ABCMeta): raise NotImplementedError @DeveloperAPI + @OverrideToImplementCustomLogic_CallToSuperRecommended def postprocess_trajectory( self, sample_batch: SampleBatch, @@ -739,9 +753,7 @@ class Policy(metaclass=ABCMeta): # steps). # Make sure, we keep global_timestep as a Tensor for tf-eager # (leads to memory leaks if not doing so). - from ray.rllib.policy.eager_tf_policy import EagerTFPolicy - - if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy): + if self.framework in ["tfe", "tf2"]: self.global_timestep.assign(global_vars["timestep"]) else: self.global_timestep = global_vars["timestep"] @@ -952,11 +964,19 @@ class Policy(metaclass=ABCMeta): train_batch[SampleBatch.SEQ_LENS] = seq_lens train_batch.count = self._dummy_batch.count # Call the loss function, if it exists. + # TODO(jungong) : clean up after all agents get migrated. + # We should simply do self.loss(...) here. 
if self._loss is not None: self._loss(self, self.model, self.dist_class, train_batch) + elif is_overridden(self.loss): + self.loss(self.model, self.dist_class, train_batch) # Call the stats fn, if given. + # TODO(jungong) : clean up after all agents get migrated. + # We should simply do self.stats_fn(train_batch) here. if stats_fn is not None: stats_fn(self, train_batch) + if hasattr(self, "stats_fn"): + self.stats_fn(train_batch) # Re-enable tracing. self._no_tracing = False @@ -974,7 +994,7 @@ class Policy(metaclass=ABCMeta): self.view_requirements[key] = ViewRequirement( used_for_compute_actions=False ) - if self._loss: + if self._loss or is_overridden(self.loss): # Tag those only needed for post-processing (with some # exceptions). for key in self._dummy_batch.accessed_keys: diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 293c38bd0..491557b2a 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -150,7 +150,7 @@ class TFPolicy(Policy): super().__init__(observation_space, action_space, config) # Get devices to build the graph on. - worker_idx = self.config.get("worker_index", 0) + worker_idx = config.get("worker_index", 0) if not config["_fake_gpus"] and ray.worker._mode() == ray.worker.LOCAL_MODE: num_gpus = 0 elif worker_idx == 0: @@ -237,7 +237,7 @@ class TFPolicy(Policy): self._action_input = action_input # For logp calculations. self._dist_inputs = dist_inputs self.dist_class = dist_class - + self._cached_extra_action_out = None self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] self._seq_lens = seq_lens @@ -784,6 +784,16 @@ class TFPolicy(Policy): @DeveloperAPI def extra_compute_action_fetches(self) -> Dict[str, TensorType]: + # Cache graph fetches for action computation for better + # performance. + # This function is called every time the static graph is run + # to compute actions. 
+ if not self._cached_extra_action_out: + self._cached_extra_action_out = self.extra_action_out_fn() + return self._cached_extra_action_out + + @DeveloperAPI + def extra_action_out_fn(self) -> Dict[str, TensorType]: """Extra values to fetch and return from compute_actions(). By default we return action probability/log-likelihood info diff --git a/rllib/policy/torch_policy_v2.py b/rllib/policy/torch_policy_v2.py index 63fabaae8..62a8cc5d5 100644 --- a/rllib/policy/torch_policy_v2.py +++ b/rllib/policy/torch_policy_v2.py @@ -230,7 +230,7 @@ class TorchPolicyV2(Policy): dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: - """Constructs the loss for Proximal Policy Objective. + """Constructs the loss function. Args: model: The Model to calculate the loss for. @@ -396,20 +396,6 @@ class TorchPolicyV2(Policy): """ return {} - @DeveloperAPI - @OverrideToImplementCustomLogic_CallToSuperRecommended - def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - """Return dict of extra grad info. - - Args: - train_batch: The training batch for which to produce - extra grad info for. - - Returns: - The info dict carrying grad info per str key. - """ - return {} - @override(Policy) @DeveloperAPI @OverrideToImplementCustomLogic_CallToSuperRecommended @@ -418,8 +404,28 @@ class TorchPolicyV2(Policy): sample_batch: SampleBatch, other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, episode: Optional["Episode"] = None, - ): - """Additional custom postprocessing of SampleBatch.""" + ) -> SampleBatch: + """Postprocesses a trajectory and returns the processed trajectory. + + The trajectory contains only data from one episode and from one agent. + - If `config.batch_mode=truncate_episodes` (default), sample_batch may + contain a truncated (at-the-end) episode, in case the + `config.rollout_fragment_length` was reached by the sampler. 
+ - If `config.batch_mode=complete_episodes`, sample_batch will contain + exactly one episode (no matter how long). + New columns can be added to sample_batch and existing ones may be altered. + + Args: + sample_batch (SampleBatch): The SampleBatch to postprocess. + other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional + dict of AgentIDs mapping to other agents' trajectory data (from the + same episode). NOTE: The other agents use the same policy. + episode (Optional[Episode]): Optional multi-agent episode + object in which the agents operated. + + Returns: + SampleBatch: The postprocessed, modified SampleBatch (or a new one). + """ return sample_batch @DeveloperAPI @@ -775,7 +781,7 @@ class TorchPolicyV2(Policy): for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)): batch_fetches[f"tower_{i}"].update( { - LEARNER_STATS_KEY: self.extra_grad_info(batch), + LEARNER_STATS_KEY: self.stats_fn(batch), "model": model.metrics(), } ) @@ -810,7 +816,7 @@ class TorchPolicyV2(Policy): all_grads, grad_info = tower_outputs[0] grad_info["allreduce_latency"] /= len(self._optimizers) - grad_info.update(self.extra_grad_info(postprocessed_batch)) + grad_info.update(self.stats_fn(postprocessed_batch)) fetches = self.extra_compute_grad_fetches()