[RLlib] Migrate MAML, MB-MPO, MARWIL, and BC to use Policy sub-classing implementation. (#24914)

Author: Jun Gong, 2022-05-20 05:10:59 -07:00 (committed by GitHub)
parent a2c8fe2101
commit d5a6d46049
16 changed files with 719 additions and 626 deletions
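
Across these 16 files the migration follows one pattern: behavior that used to be supplied as free functions to the build_tf_policy / build_policy_class builders now lives as methods (loss, stats_fn, optimizer, postprocess_trajectory) on Policy sub-classes. A minimal sketch of that shape, using toy stand-in classes rather than RLlib's real base classes:

# Toy sketch of the migration pattern (stand-in names, not RLlib classes).

# Old style: the loss is a free function handed to a builder such as
# build_policy_class(name=..., loss_fn=my_loss, stats_fn=..., ...).
def my_loss(policy, model, dist_class, train_batch):
    return sum(train_batch["rewards"])

# New style (this commit): the same logic becomes methods on a subclass.
class ToyPolicyV2:
    """Stand-in for TorchPolicyV2 / DynamicTFPolicyV2 / EagerTFPolicyV2."""

    def __init__(self, config):
        self.config = config

    def loss(self, model, dist_class, train_batch):
        raise NotImplementedError

    def stats_fn(self, train_batch):
        return {}

class MyPolicy(ToyPolicyV2):
    def loss(self, model, dist_class, train_batch):
        self.total_loss = sum(train_batch["rewards"])
        return self.total_loss

    def stats_fn(self, train_batch):
        return {"total_loss": self.total_loss}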

View file

@@ -1,13 +1,17 @@
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
MARWILDynamicTFPolicy,
MARWILEagerTFPolicy,
)
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
__all__ = [
"BCTrainer",
"BC_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"MARWILTFPolicy",
"MARWILDynamicTFPolicy",
"MARWILEagerTFPolicy",
"MARWILTorchPolicy",
"MARWILTrainer",
]

View file

@@ -241,7 +241,7 @@ def compute_and_clip_gradients(
return grads_and_vars
def setup_config(
def validate_config(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@@ -324,7 +324,7 @@ PPOTFPolicy = build_tf_policy(
stats_fn=kl_and_loss_stats,
compute_gradients_fn=compute_and_clip_gradients,
extra_action_out_fn=vf_preds_fetches,
before_init=setup_config,
before_init=validate_config,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule,

View file

@@ -2,7 +2,7 @@ import logging
from typing import Dict, List, Type, Union
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
@@ -43,7 +43,7 @@ class PPOTorchPolicy(
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
setup_config(self, observation_space, action_space, config)
validate_config(self, observation_space, action_space, config)
TorchPolicy.__init__(
self,

View file

@@ -4,8 +4,6 @@ from typing import Type
from ray.rllib.utils.sgd import standardized
from ray.rllib.agents import with_common_config
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.worker_set import WorkerSet
@@ -199,9 +197,17 @@ class MAMLTrainer(Trainer):
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
return MAMLTorchPolicy
elif config["framework"] == "tf":
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLDynamicTFPolicy
return MAMLDynamicTFPolicy
else:
return MAMLTFPolicy
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLEagerTFPolicy
return MAMLEagerTFPolicy
@staticmethod
@override(Trainer)

View file

@@ -1,20 +1,23 @@
import logging
from typing import Dict, List, Type, Union
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import (
vf_preds_fetches,
compute_and_clip_gradients,
setup_config,
)
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
)
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_mixins import ValueNetworkMixin
from ray.rllib.policy.tf_mixins import (
ComputeAndClipGradsMixIn,
ComputeGAEMixIn,
ValueNetworkMixin,
)
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
@@ -326,72 +329,6 @@ class MAMLLoss(object):
return placeholder_list
def maml_loss(policy, model, dist_class, train_batch):
logits, state = model(train_batch)
policy.cur_lr = policy.config["lr"]
if policy.config["worker_index"]:
policy.loss_obj = WorkerLoss(
dist_class=dist_class,
actions=train_batch[SampleBatch.ACTIONS],
curr_logits=logits,
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
advantages=train_batch[Postprocessing.ADVANTAGES],
value_fn=model.value_function(),
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=0.0,
entropy_coeff=policy.config["entropy_coeff"],
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
clip_loss=False,
)
else:
policy.var_list = tf1.get_collection(
tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name
)
policy.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
advantages=train_batch[Postprocessing.ADVANTAGES],
actions=train_batch[SampleBatch.ACTIONS],
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=policy.kl_coeff,
policy_vars=policy.var_list,
obs=train_batch[SampleBatch.CUR_OBS],
num_tasks=policy.config["num_workers"],
split=train_batch["split"],
config=policy.config,
inner_adaptation_steps=policy.config["inner_adaptation_steps"],
entropy_coeff=policy.config["entropy_coeff"],
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
use_gae=policy.config["use_gae"],
)
return policy.loss_obj.loss
def maml_stats(policy, train_batch):
if policy.config["worker_index"]:
return {"worker_loss": policy.loss_obj.loss}
else:
return {
"cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64),
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"total_loss": policy.loss_obj.loss,
"policy_loss": policy.loss_obj.mean_policy_loss,
"vf_loss": policy.loss_obj.mean_vf_loss,
"kl": policy.loss_obj.mean_kl,
"inner_kl": policy.loss_obj.mean_inner_kl,
"entropy": policy.loss_obj.mean_entropy,
}
class KLCoeffMixin:
def __init__(self, config):
self.kl_coeff_val = [config["kl_coeff"]] * config["inner_adaptation_steps"]
@@ -415,41 +352,154 @@ class KLCoeffMixin:
return self.kl_coeff_val
def maml_optimizer_fn(policy, config):
# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_maml_tf_policy(base: type) -> type:
"""Construct a MAMLTFPolicy inheriting either dynamic or eager base policies.
Args:
base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
Returns:
A TF Policy to be used with MAMLTrainer.
"""
Workers use simple SGD for inner adaptation
Meta-Policy uses Adam optimizer for meta-update
"""
if not config["worker_index"]:
return tf1.train.AdamOptimizer(learning_rate=config["lr"])
return tf1.train.GradientDescentOptimizer(learning_rate=config["inner_lr"])
class MAMLTFPolicy(
ComputeGAEMixIn, ComputeAndClipGradsMixIn, KLCoeffMixin, ValueNetworkMixin, base
):
def __init__(
self,
obs_space,
action_space,
config,
existing_model=None,
existing_inputs=None,
):
# First things first, enable eager execution if necessary.
base.enable_eager_execution_if_necessary()
config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config)
validate_config(self, obs_space, action_space, config)
# Initialize base class.
base.__init__(
self,
obs_space,
action_space,
config,
existing_inputs=existing_inputs,
existing_model=existing_model,
)
ComputeGAEMixIn.__init__(self)
ComputeAndClipGradsMixIn.__init__(self)
KLCoeffMixin.__init__(self, config)
ValueNetworkMixin.__init__(self, config)
# Create the `split` placeholder before initialize loss.
if self.framework == "tf":
self._loss_input_dict["split"] = tf1.placeholder(
tf.int32,
name="Meta-Update-Splitting",
shape=(
self.config["inner_adaptation_steps"] + 1,
self.config["num_workers"],
),
)
# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.
self.maybe_initialize_optimizer_and_loss()
@override(base)
def loss(
self,
model: Union[ModelV2, "tf.keras.Model"],
dist_class: Type[TFActionDistribution],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
logits, state = model(train_batch)
self.cur_lr = self.config["lr"]
if self.config["worker_index"]:
self.loss_obj = WorkerLoss(
dist_class=dist_class,
actions=train_batch[SampleBatch.ACTIONS],
curr_logits=logits,
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
advantages=train_batch[Postprocessing.ADVANTAGES],
value_fn=model.value_function(),
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=0.0,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
clip_loss=False,
)
else:
self.var_list = tf1.get_collection(
tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name
)
self.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
advantages=train_batch[Postprocessing.ADVANTAGES],
actions=train_batch[SampleBatch.ACTIONS],
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=self.kl_coeff,
policy_vars=self.var_list,
obs=train_batch[SampleBatch.CUR_OBS],
num_tasks=self.config["num_workers"],
split=train_batch["split"],
config=self.config,
inner_adaptation_steps=self.config["inner_adaptation_steps"],
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
use_gae=self.config["use_gae"],
)
return self.loss_obj.loss
@override(base)
def optimizer(
self,
) -> Union[
"tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]
]:
"""
Workers use simple SGD for inner adaptation
Meta-Policy uses Adam optimizer for meta-update
"""
if not self.config["worker_index"]:
return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
return tf1.train.GradientDescentOptimizer(
learning_rate=self.config["inner_lr"]
)
@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
if self.config["worker_index"]:
return {"worker_loss": self.loss_obj.loss}
else:
return {
"cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64),
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
"vf_loss": self.loss_obj.mean_vf_loss,
"kl": self.loss_obj.mean_kl,
"inner_kl": self.loss_obj.mean_inner_kl,
"entropy": self.loss_obj.mean_entropy,
}
return MAMLTFPolicy
def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy, config)
KLCoeffMixin.__init__(policy, config)
# Create the `split` placeholder.
policy._loss_input_dict["split"] = tf1.placeholder(
tf.int32,
name="Meta-Update-Splitting",
shape=(
policy.config["inner_adaptation_steps"] + 1,
policy.config["num_workers"],
),
)
MAMLTFPolicy = build_tf_policy(
name="MAMLTFPolicy",
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
loss_fn=maml_loss,
stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn,
extra_action_out_fn=vf_preds_fetches,
postprocess_fn=compute_gae_for_sample_batch,
compute_gradients_fn=compute_and_clip_gradients,
before_init=setup_config,
before_loss_init=setup_mixins,
mixins=[KLCoeffMixin],
)
MAMLDynamicTFPolicy = get_maml_tf_policy(DynamicTFPolicyV2)
MAMLEagerTFPolicy = get_maml_tf_policy(EagerTFPolicyV2)
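
The get_maml_tf_policy builder above (mirrored below by get_marwil_tf_policy) lets a single class body serve both the TF1 static-graph stack (DynamicTFPolicyV2) and the TF2 eager stack (EagerTFPolicyV2). A self-contained toy version of that class-builder pattern, with invented names:

# Toy illustration of the shared class-builder pattern (invented names).
class ToyDynamicTFBase:
    framework = "tf"   # stand-in for DynamicTFPolicyV2

class ToyEagerTFBase:
    framework = "tf2"  # stand-in for EagerTFPolicyV2

def get_toy_tf_policy(base: type) -> type:
    # The custom logic is written once; `base` supplies the static-graph
    # vs. eager machinery.
    class ToyTFPolicy(base):
        def loss(self, train_batch):
            return sum(train_batch["rewards"])

    return ToyTFPolicy

ToyDynamicTFPolicy = get_toy_tf_policy(ToyDynamicTFBase)
ToyEagerTFPolicy = get_toy_tf_policy(ToyEagerTFBase)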

View file

@@ -1,17 +1,20 @@
import higher
import logging
from typing import Dict, List, Type, Union
import ray
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
)
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.policy.torch_mixins import ComputeGAEMixIn, ValueNetworkMixin
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()
@@ -145,9 +148,6 @@ class MAMLLoss(object):
vf_loss_coeff=1.0,
use_gae=True,
):
import higher
self.config = config
self.num_tasks = num_tasks
self.inner_adaptation_steps = inner_adaptation_steps
@@ -266,86 +266,6 @@ class MAMLLoss(object):
return placeholder_list
def maml_loss(policy, model, dist_class, train_batch):
logits, state = model(train_batch)
policy.cur_lr = policy.config["lr"]
if policy.config["worker_index"]:
policy.loss_obj = WorkerLoss(
model=model,
dist_class=dist_class,
actions=train_batch[SampleBatch.ACTIONS],
curr_logits=logits,
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
advantages=train_batch[Postprocessing.ADVANTAGES],
value_fn=model.value_function(),
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=0.0,
entropy_coeff=policy.config["entropy_coeff"],
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
clip_loss=False,
)
else:
policy.var_list = model.named_parameters()
# `split` may not exist yet (during test-loss call), use a dummy value.
# Cannot use get here due to train_batch being a TrackingDict.
if "split" in train_batch:
split = train_batch["split"]
else:
split_shape = (
policy.config["inner_adaptation_steps"],
policy.config["num_workers"],
)
split_const = int(
train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1])
)
split = torch.ones(split_shape, dtype=int) * split_const
policy.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
advantages=train_batch[Postprocessing.ADVANTAGES],
actions=train_batch[SampleBatch.ACTIONS],
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=policy.kl_coeff_val,
policy_vars=policy.var_list,
obs=train_batch[SampleBatch.CUR_OBS],
num_tasks=policy.config["num_workers"],
split=split,
config=policy.config,
inner_adaptation_steps=policy.config["inner_adaptation_steps"],
entropy_coeff=policy.config["entropy_coeff"],
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
use_gae=policy.config["use_gae"],
meta_opt=policy.meta_opt,
)
return policy.loss_obj.loss
def maml_stats(policy, train_batch):
if policy.config["worker_index"]:
return {"worker_loss": policy.loss_obj.loss}
else:
return {
"cur_kl_coeff": policy.kl_coeff_val,
"cur_lr": policy.cur_lr,
"total_loss": policy.loss_obj.loss,
"policy_loss": policy.loss_obj.mean_policy_loss,
"vf_loss": policy.loss_obj.mean_vf_loss,
"kl_loss": policy.loss_obj.mean_kl_loss,
"inner_kl": policy.loss_obj.mean_inner_kl,
"entropy": policy.loss_obj.mean_entropy,
}
class KLCoeffMixin:
def __init__(self, config):
self.kl_coeff_val = (
@@ -364,33 +284,141 @@ class KLCoeffMixin:
return self.kl_coeff_val
def maml_optimizer_fn(policy, config):
"""
Workers use simple SGD for inner adaptation
Meta-Policy uses Adam optimizer for meta-update
"""
if not config["worker_index"]:
policy.meta_opt = torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
return policy.meta_opt
return torch.optim.SGD(policy.model.parameters(), lr=config["inner_lr"])
class MAMLTorchPolicy(ComputeGAEMixIn, ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2):
"""PyTorch policy class used with MAMLTrainer."""
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config)
validate_config(self, observation_space, action_space, config)
def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy, config)
KLCoeffMixin.__init__(policy, config)
TorchPolicyV2.__init__(
self,
observation_space,
action_space,
config,
max_seq_len=config["model"]["max_seq_len"],
)
ComputeGAEMixIn.__init__(self)
KLCoeffMixin.__init__(self, config)
ValueNetworkMixin.__init__(self, config)
MAMLTorchPolicy = build_policy_class(
name="MAMLTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
loss_fn=maml_loss,
stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn,
extra_action_out_fn=vf_preds_fetches,
postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
after_init=setup_mixins,
mixins=[KLCoeffMixin],
)
# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()
@override(TorchPolicyV2)
def loss(
self,
model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss function.
Args:
model: The Model to calculate the loss for.
dist_class: The action distr. class.
train_batch: The training data.
Returns:
The MAML loss tensor given the input batch.
"""
logits, state = model(train_batch)
self.cur_lr = self.config["lr"]
if self.config["worker_index"]:
self.loss_obj = WorkerLoss(
model=model,
dist_class=dist_class,
actions=train_batch[SampleBatch.ACTIONS],
curr_logits=logits,
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
advantages=train_batch[Postprocessing.ADVANTAGES],
value_fn=model.value_function(),
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=0.0,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
clip_loss=False,
)
else:
self.var_list = model.named_parameters()
# `split` may not exist yet (during test-loss call), use a dummy value.
# Cannot use get here due to train_batch being a TrackingDict.
if "split" in train_batch:
split = train_batch["split"]
else:
split_shape = (
self.config["inner_adaptation_steps"],
self.config["num_workers"],
)
split_const = int(
train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1])
)
split = torch.ones(split_shape, dtype=int) * split_const
self.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
advantages=train_batch[Postprocessing.ADVANTAGES],
actions=train_batch[SampleBatch.ACTIONS],
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=self.kl_coeff_val,
policy_vars=self.var_list,
obs=train_batch[SampleBatch.CUR_OBS],
num_tasks=self.config["num_workers"],
split=split,
config=self.config,
inner_adaptation_steps=self.config["inner_adaptation_steps"],
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
use_gae=self.config["use_gae"],
meta_opt=self.meta_opt,
)
return self.loss_obj.loss
@override(TorchPolicyV2)
def optimizer(
self,
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
"""
Workers use simple SGD for inner adaptation
Meta-Policy uses Adam optimizer for meta-update
"""
if not self.config["worker_index"]:
self.meta_opt = torch.optim.Adam(
self.model.parameters(), lr=self.config["lr"]
)
return self.meta_opt
return torch.optim.SGD(self.model.parameters(), lr=self.config["inner_lr"])
@override(TorchPolicyV2)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
if self.config["worker_index"]:
return convert_to_numpy({"worker_loss": self.loss_obj.loss})
else:
return convert_to_numpy(
{
"cur_kl_coeff": self.kl_coeff_val,
"cur_lr": self.cur_lr,
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
"vf_loss": self.loss_obj.mean_vf_loss,
"kl_loss": self.loss_obj.mean_kl_loss,
"inner_kl": self.loss_obj.mean_inner_kl,
"entropy": self.loss_obj.mean_entropy,
}
)
def extra_grad_process(
self, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
return apply_grad_clipping(self, optimizer, loss)
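
The optimizer() override above keeps MAML's two-level optimization visible in one place: worker policies (worker_index > 0) adapt with plain SGD at the inner learning rate, while the meta-learner (worker_index == 0) updates with Adam at the outer learning rate. A small, torch-free sketch of just that branching, for illustration only:

# Illustration only: which optimizer and learning rate each process selects.
def choose_optimizer(config):
    if not config["worker_index"]:
        return ("Adam", config["lr"])       # meta-update on the learner
    return ("SGD", config["inner_lr"])      # inner adaptation on workers

assert choose_optimizer({"worker_index": 0, "lr": 1e-3, "inner_lr": 0.1}) == ("Adam", 1e-3)
assert choose_optimizer({"worker_index": 2, "lr": 1e-3, "inner_lr": 0.1}) == ("SGD", 0.1)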

View file

@@ -1,13 +1,17 @@
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
MARWILDynamicTFPolicy,
MARWILEagerTFPolicy,
)
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
__all__ = [
"BCTrainer",
"BC_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"MARWILTFPolicy",
"MARWILDynamicTFPolicy",
"MARWILEagerTFPolicy",
"MARWILTorchPolicy",
"MARWILTrainer",
]

View file

@@ -1,7 +1,6 @@
from typing import Type
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
@@ -123,8 +122,16 @@ class MARWILTrainer(Trainer):
)
return MARWILTorchPolicy
elif config["framework"] == "tf":
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
MARWILDynamicTFPolicy,
)
return MARWILDynamicTFPolicy
else:
return MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy
return MARWILEagerTFPolicy
@override(Trainer)
def training_iteration(self) -> ResultDict:

View file

@@ -1,100 +1,69 @@
import logging
import gym
from typing import Optional, Dict
from typing import Any, Dict, List, Optional, Type, Union
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, PolicyID
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import ComputeAndClipGradsMixIn, ValueNetworkMixin
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
class ValueNetworkMixin:
def __init__(
class PostprocessAdvantages:
"""Marwil's custom trajectory post-processing mixin."""
def __init__(self):
pass
def postprocess_trajectory(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
episode: Optional["Episode"] = None,
):
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
@make_tf_callable(self.get_session())
def value(**input_dict):
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0]
self._value = value
def postprocess_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode=None,
) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
contain a truncated (at-the-end) episode, in case the
`config.rollout_fragment_length` was reached by the sampler.
- If `config.batch_mode=complete_episodes`, sample_batch will contain
exactly one episode (no matter how long).
New columns can be added to sample_batch and existing ones may be altered.
Args:
policy (Policy): The Policy used to generate the trajectory
(`sample_batch`)
sample_batch (SampleBatch): The SampleBatch to postprocess.
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
Returns:
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
# Trajectory is actually complete -> last r=0.0.
if sample_batch[SampleBatch.DONES][-1]:
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
# Create an input dict according to the Model's requirements.
index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
input_dict = sample_batch.get_single_step_input_dict(
policy.model.view_requirements, index=index
sample_batch = super().postprocess_trajectory(
sample_batch, other_agent_batches, episode
)
last_r = policy._value(**input_dict)
# Adds the "advantages" (which in the case of MARWIL are simply the
# discounted cumulative rewards) to the SampleBatch.
return compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
# We just want the discounted cumulative rewards, so we won't need
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
use_gae=False,
use_critic=False,
)
# Trajectory is actually complete -> last r=0.0.
if sample_batch[SampleBatch.DONES][-1]:
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
# Create an input dict according to the Model's requirements.
index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
input_dict = sample_batch.get_single_step_input_dict(
self.model.view_requirements, index=index
)
last_r = self._value(**input_dict)
# Adds the "advantages" (which in the case of MARWIL are simply the
# discounted cumulative rewards) to the SampleBatch.
return compute_advantages(
sample_batch,
last_r,
self.config["gamma"],
# We just want the discounted cumulative rewards, so we won't need
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
use_gae=False,
use_critic=False,
)
class MARWILLoss:
@@ -174,69 +143,100 @@ class MARWILLoss:
self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
def marwil_loss(
policy: Policy,
model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch,
) -> TensorType:
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
value_estimates = model.value_function()
# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_marwil_tf_policy(base: type) -> type:
"""Construct a MARWILTFPolicy inheriting either dynamic or eager base policies.
policy.loss = MARWILLoss(
policy,
value_estimates,
action_dist,
train_batch,
policy.config["vf_coeff"],
policy.config["beta"],
)
Args:
base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
return policy.loss.total_loss
Returns:
A TF Policy to be used with MARWILTrainer.
"""
class MARWILTFPolicy(
ComputeAndClipGradsMixIn, ValueNetworkMixin, PostprocessAdvantages, base
):
def __init__(
self,
obs_space,
action_space,
config,
existing_model=None,
existing_inputs=None,
):
# First things first, enable eager execution if necessary.
base.enable_eager_execution_if_necessary()
config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config)
# Initialize base class.
base.__init__(
self,
obs_space,
action_space,
config,
existing_inputs=existing_inputs,
existing_model=existing_model,
)
ComputeAndClipGradsMixIn.__init__(self)
ValueNetworkMixin.__init__(self, config)
PostprocessAdvantages.__init__(self)
# Not needed for pure BC.
if config["beta"] != 0.0:
# Set up a tf-var for the moving avg (do this here to make it work
# with eager mode); "c^2" in the paper.
self._moving_average_sqd_adv_norm = get_variable(
config["moving_average_sqd_adv_norm_start"],
framework="tf",
tf_name="moving_average_of_advantage_norm",
trainable=False,
)
# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.
self.maybe_initialize_optimizer_and_loss()
@override(base)
def loss(
self,
model: Union[ModelV2, "tf.keras.Model"],
dist_class: Type[TFActionDistribution],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
value_estimates = model.value_function()
self.loss = MARWILLoss(
self,
value_estimates,
action_dist,
train_batch,
self.config["vf_coeff"],
self.config["beta"],
)
return self.loss.total_loss
@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"policy_loss": self.loss.p_loss,
"total_loss": self.loss.total_loss,
}
if self.config["beta"] != 0.0:
stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
stats["vf_explained_var"] = self.loss.explained_variance
stats["vf_loss"] = self.loss.v_loss
return stats
return MARWILTFPolicy
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"policy_loss": policy.loss.p_loss,
"total_loss": policy.loss.total_loss,
}
if policy.config["beta"] != 0.0:
stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm
stats["vf_explained_var"] = policy.loss.explained_variance
stats["vf_loss"] = policy.loss.v_loss
return stats
def setup_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
) -> None:
# Setup Value branch of our NN.
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
# Not needed for pure BC.
if policy.config["beta"] != 0.0:
# Set up a tf-var for the moving avg (do this here to make it work
# with eager mode); "c^2" in the paper.
policy._moving_average_sqd_adv_norm = get_variable(
policy.config["moving_average_sqd_adv_norm_start"],
framework="tf",
tf_name="moving_average_of_advantage_norm",
trainable=False,
)
MARWILTFPolicy = build_tf_policy(
name="MARWILTFPolicy",
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
loss_fn=marwil_loss,
stats_fn=stats,
postprocess_fn=postprocess_advantages,
before_loss_init=setup_mixins,
compute_gradients_fn=compute_and_clip_gradients,
mixins=[ValueNetworkMixin],
)
MARWILDynamicTFPolicy = get_marwil_tf_policy(DynamicTFPolicyV2)
MARWILEagerTFPolicy = get_marwil_tf_policy(EagerTFPolicyV2)
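
PostprocessAdvantages above calls compute_advantages with use_gae=False and use_critic=False, so MARWIL's "advantages" are simply discounted cumulative rewards, bootstrapped with last_r (the value estimate of the final observation when the episode was truncated, else 0.0). A tiny self-contained sketch of that reduction, not the RLlib implementation:

def discounted_returns(rewards, last_r, gamma):
    # Backward pass: R_t = r_t + gamma * R_{t+1}, seeded with last_r.
    out, running = [], last_r
    for r in reversed(rewards):
        running = r + gamma * running
        out.append(running)
    return list(reversed(out))

# Example: rewards [1, 0, 2], complete episode (last_r=0.0), gamma=0.5.
assert discounted_returns([1, 0, 2], last_r=0.0, gamma=0.5) == [1.5, 1.0, 2.0]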

View file

@@ -1,122 +1,124 @@
import gym
from typing import Dict
from typing import Dict, List, Type, Union
import ray
from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance
from ray.rllib.utils.typing import TrainerConfigDict, TensorType
from ray.rllib.policy.policy import Policy
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.typing import TensorType
torch, _ = try_import_torch()
def marwil_loss(
policy: Policy,
model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch,
) -> TensorType:
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
actions = train_batch[SampleBatch.ACTIONS]
# log\pi_\theta(a|s)
logprobs = action_dist.logp(actions)
class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
"""PyTorch policy class used with MarwilTrainer."""
# Advantage estimation.
if policy.config["beta"] != 0.0:
cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
state_values = model.value_function()
adv = cumulative_rewards - state_values
adv_squared_mean = torch.mean(torch.pow(adv, 2.0))
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config)
explained_var = explained_variance(cumulative_rewards, state_values)
policy.explained_variance = torch.mean(explained_var)
# Policy loss.
# Update averaged advantage norm.
rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
policy._moving_average_sqd_adv_norm.add_(
rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm)
TorchPolicyV2.__init__(
self,
observation_space,
action_space,
config,
max_seq_len=config["model"]["max_seq_len"],
)
# Exponentially weighted advantages.
exp_advs = torch.exp(
policy.config["beta"]
* (adv / (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5)))
).detach()
# Value loss.
policy.v_loss = 0.5 * adv_squared_mean
else:
# Policy loss (simple BC loss term).
exp_advs = 1.0
# Value loss.
policy.v_loss = 0.0
# logprob loss alone tends to push action distributions to
# have very low entropy, resulting in worse performance for
# unfamiliar situations.
# A scaled logstd loss term encourages stochasticity, thus
# alleviating the problem to some extent.
logstd_coeff = policy.config["bc_logstd_coeff"]
if logstd_coeff > 0.0:
logstds = torch.mean(action_dist.log_std, dim=1)
else:
logstds = 0.0
ValueNetworkMixin.__init__(self, config)
PostprocessAdvantages.__init__(self)
policy.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))
# Not needed for pure BC.
if config["beta"] != 0.0:
# Set up a torch-var for the squared moving avg. advantage norm.
self._moving_average_sqd_adv_norm = torch.tensor(
[config["moving_average_sqd_adv_norm_start"]],
dtype=torch.float32,
requires_grad=False,
).to(self.device)
# Combine both losses.
policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * policy.v_loss
# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()
return policy.total_loss
@override(TorchPolicyV2)
def loss(
self,
model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
actions = train_batch[SampleBatch.ACTIONS]
# log\pi_\theta(a|s)
logprobs = action_dist.logp(actions)
# Advantage estimation.
if self.config["beta"] != 0.0:
cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
state_values = model.value_function()
adv = cumulative_rewards - state_values
adv_squared_mean = torch.mean(torch.pow(adv, 2.0))
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"policy_loss": policy.p_loss,
"total_loss": policy.total_loss,
}
if policy.config["beta"] != 0.0:
stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm
stats["vf_explained_var"] = policy.explained_variance
stats["vf_loss"] = policy.v_loss
explained_var = explained_variance(cumulative_rewards, state_values)
self.explained_variance = torch.mean(explained_var)
return stats
# Policy loss.
# Update averaged advantage norm.
rate = self.config["moving_average_sqd_adv_norm_update_rate"]
self._moving_average_sqd_adv_norm.add_(
rate * (adv_squared_mean - self._moving_average_sqd_adv_norm)
)
# Exponentially weighted advantages.
exp_advs = torch.exp(
self.config["beta"]
* (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5)))
).detach()
# Value loss.
self.v_loss = 0.5 * adv_squared_mean
else:
# Policy loss (simple BC loss term).
exp_advs = 1.0
# Value loss.
self.v_loss = 0.0
# logprob loss alone tends to push action distributions to
# have very low entropy, resulting in worse performance for
# unfamiliar situations.
# A scaled logstd loss term encourages stochasticity, thus
# alleviating the problem to some extent.
logstd_coeff = self.config["bc_logstd_coeff"]
if logstd_coeff > 0.0:
logstds = torch.mean(action_dist.log_std, dim=1)
else:
logstds = 0.0
def setup_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
) -> None:
# Setup Value branch of our NN.
ValueNetworkMixin.__init__(policy, config)
self.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))
# Not needed for pure BC.
if policy.config["beta"] != 0.0:
# Set up a torch-var for the squared moving avg. advantage norm.
policy._moving_average_sqd_adv_norm = torch.tensor(
[policy.config["moving_average_sqd_adv_norm_start"]],
dtype=torch.float32,
requires_grad=False,
).to(policy.device)
# Combine both losses.
self.total_loss = self.p_loss + self.config["vf_coeff"] * self.v_loss
return self.total_loss
MARWILTorchPolicy = build_policy_class(
name="MARWILTorchPolicy",
framework="torch",
loss_fn=marwil_loss,
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
stats_fn=stats,
postprocess_fn=postprocess_advantages,
extra_grad_process_fn=apply_grad_clipping,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin],
)
@override(TorchPolicyV2)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"policy_loss": self.p_loss,
"total_loss": self.total_loss,
}
if self.config["beta"] != 0.0:
stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
stats["vf_explained_var"] = self.explained_variance
stats["vf_loss"] = self.v_loss
return convert_to_numpy(stats)
def extra_grad_process(
self, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
return apply_grad_clipping(self, optimizer, loss)
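
In the beta != 0.0 branch of the loss above, the behavior-cloning term is weighted by exp(beta * advantage / sqrt(c^2)), where c^2 is the moving average of the squared advantage norm that the constructor initializes. A torch-free sketch of that weighting, for illustration only:

import math

def marwil_weights(advantages, moving_avg_sqd_adv_norm, beta, eps=1e-8):
    # weight_i = exp(beta * adv_i / sqrt(c^2)): better-than-average actions
    # get weights > 1, worse-than-average actions get weights < 1.
    scale = eps + math.sqrt(moving_avg_sqd_adv_norm)
    return [math.exp(beta * a / scale) for a in advantages]

# With beta=1.0 and c^2=4.0, an advantage of +2 is up-weighted by e ~ 2.72 and
# an advantage of -2 is down-weighted to 1/e ~ 0.37; beta=0.0 recovers plain BC.
weights = marwil_weights([2.0, -2.0], moving_avg_sqd_adv_norm=4.0, beta=1.0)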

View file

@@ -5,6 +5,8 @@ import unittest
import ray
import ray.rllib.algorithms.marwil as marwil
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.offline import JsonReader
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -182,9 +184,7 @@ class TestMARWIL(unittest.TestCase):
batch.set_get_interceptor(None)
postprocessed_batch = policy.postprocess_trajectory(batch)
loss_func = (
marwil.marwil_tf_policy.marwil_loss
if fw != "torch"
else marwil.marwil_torch_policy.marwil_loss
MARWILEagerTFPolicy.loss if fw != "torch" else MARWILTorchPolicy.loss
)
if fw != "tf":
policy._lazy_tensor_dict(postprocessed_batch)

View file

@@ -4,7 +4,6 @@ from typing import List, Type
import ray
from ray.rllib.agents import with_common_config
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.agents.trainer import Trainer
@@ -374,6 +373,8 @@ class MBMPOTrainer(Trainer):
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
return MBMPOTorchPolicy
@staticmethod

View file

@@ -1,135 +1,90 @@
import gym
from gym.spaces import Box, Discrete
import logging
from typing import Tuple, Type
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
from ray.rllib.algorithms.maml.maml_torch_policy import (
setup_mixins,
maml_loss,
maml_stats,
maml_optimizer_fn,
KLCoeffMixin,
)
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import TrainerConfigDict
torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
def validate_spaces(
policy: Policy,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
) -> None:
"""Validates the observation- and action spaces used for the Policy.
class MBMPOTorchPolicy(MAMLTorchPolicy):
def __init__(self, observation_space, action_space, config):
# Validate spaces.
# Only support single Box or single Discrete spaces.
if not isinstance(action_space, (Box, Discrete)):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "
"MB-MPO. Must be [Box|Discrete].".format(action_space, self)
)
# If Box, make sure it's a 1D vector space.
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space ({}) of {} has multiple dimensions "
"{}. ".format(action_space, self, action_space.shape)
+ "Consider reshaping this into a single dimension Box space "
"or using the multi-agent API."
)
Args:
policy (Policy): The policy, whose spaces are being validated.
observation_space (gym.spaces.Space): The observation space to
validate.
action_space (gym.spaces.Space): The action space to validate.
config (TrainerConfigDict): The Policy's config dict.
config = dict(ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG, **config)
super().__init__(observation_space, action_space, config)
Raises:
UnsupportedSpaceException: If one of the spaces is not supported.
"""
# Only support single Box or single Discrete spaces.
if not isinstance(action_space, (Box, Discrete)):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "
"MB-MPO. Must be [Box|Discrete].".format(action_space, policy)
)
# If Box, make sure it's a 1D vector space.
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space ({}) of {} has multiple dimensions "
"{}. ".format(action_space, policy, action_space.shape)
+ "Consider reshaping this into a single dimension Box space "
"or using the multi-agent API."
def make_model_and_action_dist(
self,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
"""Constructs the necessary ModelV2 and action dist class for the Policy.
Args:
obs_space (gym.spaces.Space): The observation space.
action_space (gym.spaces.Space): The action space.
config (TrainerConfigDict): The MB-MPO trainer's config dict.
Returns:
ModelV2: The ModelV2 to be used by the Policy. Note: An additional
target model will be created in this function and assigned to
`policy.target_model`.
"""
# Get the output distribution class for predicting rewards and next-obs.
self.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
self.observation_space,
self.config,
dist_type="deterministic",
framework="torch",
)
# Build one dynamics model if we are a Worker.
# If we are the main MAML learner, build n (num_workers) dynamics Models
# for being able to create checkpoints for the current state of training.
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.dynamics_model = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
num_outputs=num_outputs,
model_config=self.config["dynamics_model"],
framework="torch",
name="dynamics_ensemble",
).to(device)
def make_model_and_action_dist(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
"""Constructs the necessary ModelV2 and action dist class for the Policy.
action_dist, num_outputs = ModelCatalog.get_action_dist(
self.action_space, self.config, framework="torch"
)
# Create the pi-model and register it with the Policy.
self.pi = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
num_outputs=num_outputs,
model_config=self.config["model"],
framework="torch",
name="policy_model",
)
Args:
policy (Policy): The TFPolicy that will use the models.
obs_space (gym.spaces.Space): The observation space.
action_space (gym.spaces.Space): The action space.
config (TrainerConfigDict): The MB-MPO trainer's config dict.
Returns:
ModelV2: The ModelV2 to be used by the Policy. Note: An additional
target model will be created in this function and assigned to
`policy.target_model`.
"""
# Get the output distribution class for predicting rewards and next-obs.
policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
obs_space, config, dist_type="deterministic", framework="torch"
)
# Build one dynamics model if we are a Worker.
# If we are the main MAML learner, build n (num_workers) dynamics Models
# for being able to create checkpoints for the current state of training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
policy.dynamics_model = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs=num_outputs,
model_config=config["dynamics_model"],
framework="torch",
name="dynamics_ensemble",
).to(device)
action_dist, num_outputs = ModelCatalog.get_action_dist(
action_space, config, framework="torch"
)
# Create the pi-model and register it with the Policy.
policy.pi = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework="torch",
name="policy_model",
)
return policy.pi, action_dist
# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
MBMPOTorchPolicy = build_policy_class(
name="MBMPOTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG,
make_model_and_action_dist=make_model_and_action_dist,
loss_fn=maml_loss,
stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn,
extra_action_out_fn=vf_preds_fetches,
postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
after_init=setup_mixins,
mixins=[KLCoeffMixin],
)
return self.pi, action_dist
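
MBMPOTorchPolicy above now simply extends MAMLTorchPolicy, merging its own defaults into the incoming config before delegating to the parent constructor. A toy sketch of that dict(DEFAULT_CONFIG, **config) merge order, with invented values:

# dict(DEFAULTS, **config): user-provided keys override the defaults.
DEFAULTS = {"lr": 1e-3, "inner_adaptation_steps": 1}
config = {"lr": 5e-4}
merged = dict(DEFAULTS, **config)
assert merged == {"lr": 5e-4, "inner_adaptation_steps": 1}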

View file

@@ -29,6 +29,8 @@ from ray.rllib.utils.annotations import (
DeveloperAPI,
ExperimentalAPI,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
is_overridden,
)
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.exploration.exploration import Exploration
@@ -155,6 +157,20 @@ class Policy(metaclass=ABCMeta):
# Child classes may set this.
self.dist_class: Optional[Type] = None
# Initialize view requirements.
self.init_view_requirements()
# Whether the Model's initial state (method) has been added
# automatically based on the given view requirements of the model.
self._model_init_state_automatically_added = False
@DeveloperAPI
def init_view_requirements(self):
"""Maximal view requirements dict for `learn_on_batch()` and
`compute_actions` calls.
Specific policies can override this function to provide custom
list of view requirements.
"""
# Maximal view requirements dict for `learn_on_batch()` and
# `compute_actions` calls.
# View requirements will be automatically filtered out later based
@@ -167,9 +183,6 @@ class Policy(metaclass=ABCMeta):
for k, v in view_reqs.items():
if k not in self.view_requirements:
self.view_requirements[k] = v
# Whether the Model's initial state (method) has been added
# automatically based on the given view requirements of the model.
self._model_init_state_automatically_added = False
@DeveloperAPI
def compute_single_action(
@@ -413,6 +426,7 @@ class Policy(metaclass=ABCMeta):
raise NotImplementedError
@DeveloperAPI
@OverrideToImplementCustomLogic_CallToSuperRecommended
def postprocess_trajectory(
self,
sample_batch: SampleBatch,
@@ -739,9 +753,7 @@ class Policy(metaclass=ABCMeta):
# steps).
# Make sure, we keep global_timestep as a Tensor for tf-eager
# (leads to memory leaks if not doing so).
from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy):
if self.framework in ["tfe", "tf2"]:
self.global_timestep.assign(global_vars["timestep"])
else:
self.global_timestep = global_vars["timestep"]
@@ -952,11 +964,19 @@ class Policy(metaclass=ABCMeta):
train_batch[SampleBatch.SEQ_LENS] = seq_lens
train_batch.count = self._dummy_batch.count
# Call the loss function, if it exists.
# TODO(jungong) : clean up after all agents get migrated.
# We should simply do self.loss(...) here.
if self._loss is not None:
self._loss(self, self.model, self.dist_class, train_batch)
elif is_overridden(self.loss):
self.loss(self.model, self.dist_class, train_batch)
# Call the stats fn, if given.
# TODO(jungong) : clean up after all agents get migrated.
# We should simply do self.stats_fn(train_batch) here.
if stats_fn is not None:
stats_fn(self, train_batch)
if hasattr(self, "stats_fn"):
self.stats_fn(train_batch)
# Re-enable tracing.
self._no_tracing = False
@@ -974,7 +994,7 @@ class Policy(metaclass=ABCMeta):
self.view_requirements[key] = ViewRequirement(
used_for_compute_actions=False
)
if self._loss:
if self._loss or is_overridden(self.loss):
# Tag those only needed for post-processing (with some
# exceptions).
for key in self._dummy_batch.accessed_keys:
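
The _initialize_loss_from_dummy_batch changes above let old- and new-style policies coexist during the migration: the function-based path still goes through self._loss and the passed-in stats_fn, while sub-classing policies are detected with is_overridden(self.loss) and a stats_fn attribute. A toy version of that detection (RLlib's actual is_overridden helper is implemented differently; this only shows the idea):

class BasePolicy:
    def loss(self, model, dist_class, train_batch):
        raise NotImplementedError

def is_overridden(bound_method) -> bool:
    # Toy check: the method counts as overridden if it is not the one
    # defined on BasePolicy itself.
    return bound_method.__func__ is not getattr(BasePolicy, bound_method.__name__)

class OldStylePolicy(BasePolicy):
    pass  # relies on a separately supplied loss function

class NewStylePolicy(BasePolicy):
    def loss(self, model, dist_class, train_batch):
        return 0.0

assert not is_overridden(OldStylePolicy().loss)
assert is_overridden(NewStylePolicy().loss)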

View file

@@ -150,7 +150,7 @@ class TFPolicy(Policy):
super().__init__(observation_space, action_space, config)
# Get devices to build the graph on.
worker_idx = self.config.get("worker_index", 0)
worker_idx = config.get("worker_index", 0)
if not config["_fake_gpus"] and ray.worker._mode() == ray.worker.LOCAL_MODE:
num_gpus = 0
elif worker_idx == 0:
@@ -237,7 +237,7 @@ class TFPolicy(Policy):
self._action_input = action_input # For logp calculations.
self._dist_inputs = dist_inputs
self.dist_class = dist_class
self._cached_extra_action_out = None
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
self._seq_lens = seq_lens
@@ -784,6 +784,16 @@ class TFPolicy(Policy):
@DeveloperAPI
def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
# Cache graph fetches for action computation for better
# performance.
# This function is called every time the static graph is run
# to compute actions.
if not self._cached_extra_action_out:
self._cached_extra_action_out = self.extra_action_out_fn()
return self._cached_extra_action_out
@DeveloperAPI
def extra_action_out_fn(self) -> Dict[str, TensorType]:
"""Extra values to fetch and return from compute_actions().
By default we return action probability/log-likelihood info

View file

@@ -230,7 +230,7 @@ class TorchPolicyV2(Policy):
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.
"""Constructs the loss function.
Args:
model: The Model to calculate the loss for.
@@ -396,20 +396,6 @@ class TorchPolicyV2(Policy):
"""
return {}
@DeveloperAPI
@OverrideToImplementCustomLogic_CallToSuperRecommended
def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Return dict of extra grad info.
Args:
train_batch: The training batch for which to produce
extra grad info for.
Returns:
The info dict carrying grad info per str key.
"""
return {}
@override(Policy)
@DeveloperAPI
@OverrideToImplementCustomLogic_CallToSuperRecommended
@@ -418,8 +404,28 @@ class TorchPolicyV2(Policy):
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
episode: Optional["Episode"] = None,
):
"""Additional custom postprocessing of SampleBatch."""
) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
contain a truncated (at-the-end) episode, in case the
`config.rollout_fragment_length` was reached by the sampler.
- If `config.batch_mode=complete_episodes`, sample_batch will contain
exactly one episode (no matter how long).
New columns can be added to sample_batch and existing ones may be altered.
Args:
sample_batch (SampleBatch): The SampleBatch to postprocess.
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
Returns:
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
return sample_batch
@DeveloperAPI
@@ -775,7 +781,7 @@ class TorchPolicyV2(Policy):
for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
batch_fetches[f"tower_{i}"].update(
{
LEARNER_STATS_KEY: self.extra_grad_info(batch),
LEARNER_STATS_KEY: self.stats_fn(batch),
"model": model.metrics(),
}
)
@@ -810,7 +816,7 @@ class TorchPolicyV2(Policy):
all_grads, grad_info = tower_outputs[0]
grad_info["allreduce_latency"] /= len(self._optimizers)
grad_info.update(self.extra_grad_info(postprocessed_batch))
grad_info.update(self.stats_fn(postprocessed_batch))
fetches = self.extra_compute_grad_fetches()