[RLlib] Migrate MAML, MB-MPO, MARWIL, and BC to use Policy sub-classing implementation. (#24914)
This commit is contained in:
parent a2c8fe2101
commit d5a6d46049
16 changed files with 719 additions and 626 deletions
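The changes below follow one recurring pattern: policies that were previously assembled by the functional builders (`build_tf_policy` / `build_policy_class`) are rewritten as explicit sub-classes of `DynamicTFPolicyV2`, `EagerTFPolicyV2`, or `TorchPolicyV2` that override methods such as `loss()`, `stats_fn()`, and `optimizer()`, with the former `setup_mixins`-style helpers becoming ordinary base-class `__init__` calls. As orientation, here is a minimal, self-contained sketch of that before/after shape; `BasePolicy`, `build_policy`, and `my_loss` are hypothetical placeholders, not RLlib APIs.

# Hypothetical, simplified sketch of the migration pattern in this commit.
# `BasePolicy`, `build_policy`, and `my_loss` are placeholders, not RLlib APIs.


class BasePolicy:
    """Stand-in for a framework-specific policy base class."""

    def __init__(self, config):
        self.config = config

    def loss(self, model, dist_class, train_batch):
        raise NotImplementedError


# Old style: a builder assembles the policy class from loose functions.
def my_loss(policy, model, dist_class, train_batch):
    return 0.0  # placeholder loss value


def build_policy(name, loss_fn):
    # Dynamically create a subclass whose `loss()` delegates to `loss_fn`.
    return type(
        name,
        (BasePolicy,),
        {"loss": lambda self, m, d, b: loss_fn(self, m, d, b)},
    )


OldStylePolicy = build_policy("OldStylePolicy", my_loss)


# New style (what this commit migrates to): sub-class the base policy and
# override the methods directly; shared logic lives in plain mixin classes.
class NewStylePolicy(BasePolicy):
    def loss(self, model, dist_class, train_batch):
        return 0.0  # placeholder loss value


if __name__ == "__main__":
    old = OldStylePolicy({"lr": 1e-3})
    new = NewStylePolicy({"lr": 1e-3})
    print(old.loss(None, None, None), new.loss(None, None, None))

In the diff itself, the same shape appears as `get_maml_tf_policy(base)` / `get_marwil_tf_policy(base)` returning a class that mixes `ComputeGAEMixIn`, `ComputeAndClipGradsMixIn`, `KLCoeffMixin`, and `ValueNetworkMixin` into the chosen TF base class, and as `MAMLTorchPolicy` / `MARWILTorchPolicy` sub-classing `TorchPolicyV2` directly.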
@@ -1,13 +1,17 @@
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
    MARWILDynamicTFPolicy,
    MARWILEagerTFPolicy,
)
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy

__all__ = [
    "BCTrainer",
    "BC_DEFAULT_CONFIG",
    "DEFAULT_CONFIG",
    "MARWILTFPolicy",
    "MARWILDynamicTFPolicy",
    "MARWILEagerTFPolicy",
    "MARWILTorchPolicy",
    "MARWILTrainer",
]
@@ -241,7 +241,7 @@ def compute_and_clip_gradients(
    return grads_and_vars


def setup_config(
def validate_config(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
@@ -324,7 +324,7 @@ PPOTFPolicy = build_tf_policy(
    stats_fn=kl_and_loss_stats,
    compute_gradients_fn=compute_and_clip_gradients,
    extra_action_out_fn=vf_preds_fetches,
    before_init=setup_config,
    before_init=validate_config,
    before_loss_init=setup_mixins,
    mixins=[
        LearningRateSchedule,
@@ -2,7 +2,7 @@ import logging
from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import (
    compute_gae_for_sample_batch,
    Postprocessing,
@@ -43,7 +43,7 @@ class PPOTorchPolicy(

    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
        setup_config(self, observation_space, action_space, config)
        validate_config(self, observation_space, action_space, config)

        TorchPolicy.__init__(
            self,
@@ -4,8 +4,6 @@ from typing import Type

from ray.rllib.utils.sgd import standardized
from ray.rllib.agents import with_common_config
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.worker_set import WorkerSet
@@ -199,9 +197,17 @@ class MAMLTrainer(Trainer):
    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy

            return MAMLTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.maml.maml_tf_policy import MAMLDynamicTFPolicy

            return MAMLDynamicTFPolicy
        else:
            return MAMLTFPolicy
            from ray.rllib.algorithms.maml.maml_tf_policy import MAMLEagerTFPolicy

            return MAMLEagerTFPolicy

    @staticmethod
    @override(Trainer)
@@ -1,20 +1,23 @@
|
|||
import logging
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import (
|
||||
vf_preds_fetches,
|
||||
compute_and_clip_gradients,
|
||||
setup_config,
|
||||
)
|
||||
from ray.rllib.evaluation.postprocessing import (
|
||||
compute_gae_for_sample_batch,
|
||||
Postprocessing,
|
||||
)
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
||||
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_mixins import ValueNetworkMixin
|
||||
from ray.rllib.policy.tf_mixins import (
|
||||
ComputeAndClipGradsMixIn,
|
||||
ComputeGAEMixIn,
|
||||
ValueNetworkMixin,
|
||||
)
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
@@ -326,72 +329,6 @@ class MAMLLoss(object):
|
|||
return placeholder_list
|
||||
|
||||
|
||||
def maml_loss(policy, model, dist_class, train_batch):
|
||||
logits, state = model(train_batch)
|
||||
policy.cur_lr = policy.config["lr"]
|
||||
|
||||
if policy.config["worker_index"]:
|
||||
policy.loss_obj = WorkerLoss(
|
||||
dist_class=dist_class,
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
curr_logits=logits,
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
value_fn=model.value_function(),
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=0.0,
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
clip_loss=False,
|
||||
)
|
||||
else:
|
||||
policy.var_list = tf1.get_collection(
|
||||
tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name
|
||||
)
|
||||
policy.loss_obj = MAMLLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=policy.kl_coeff,
|
||||
policy_vars=policy.var_list,
|
||||
obs=train_batch[SampleBatch.CUR_OBS],
|
||||
num_tasks=policy.config["num_workers"],
|
||||
split=train_batch["split"],
|
||||
config=policy.config,
|
||||
inner_adaptation_steps=policy.config["inner_adaptation_steps"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
)
|
||||
|
||||
return policy.loss_obj.loss
|
||||
|
||||
|
||||
def maml_stats(policy, train_batch):
|
||||
if policy.config["worker_index"]:
|
||||
return {"worker_loss": policy.loss_obj.loss}
|
||||
else:
|
||||
return {
|
||||
"cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64),
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"total_loss": policy.loss_obj.loss,
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss,
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"kl": policy.loss_obj.mean_kl,
|
||||
"inner_kl": policy.loss_obj.mean_inner_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
}
|
||||
|
||||
|
||||
class KLCoeffMixin:
|
||||
def __init__(self, config):
|
||||
self.kl_coeff_val = [config["kl_coeff"]] * config["inner_adaptation_steps"]
|
||||
|
@@ -415,41 +352,154 @@ class KLCoeffMixin:
|
|||
return self.kl_coeff_val
|
||||
|
||||
|
||||
def maml_optimizer_fn(policy, config):
|
||||
# We need this builder function because we want to share the same
|
||||
# custom logics between TF1 dynamic and TF2 eager policies.
|
||||
def get_maml_tf_policy(base: type) -> type:
|
||||
"""Construct a MAMLTFPolicy inheriting either dynamic or eager base policies.
|
||||
|
||||
Args:
|
||||
base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
|
||||
|
||||
Returns:
|
||||
A TF Policy to be used with MAMLTrainer.
|
||||
"""
|
||||
Workers use simple SGD for inner adaptation
|
||||
Meta-Policy uses Adam optimizer for meta-update
|
||||
"""
|
||||
if not config["worker_index"]:
|
||||
return tf1.train.AdamOptimizer(learning_rate=config["lr"])
|
||||
return tf1.train.GradientDescentOptimizer(learning_rate=config["inner_lr"])
|
||||
|
||||
class MAMLTFPolicy(
|
||||
ComputeGAEMixIn, ComputeAndClipGradsMixIn, KLCoeffMixin, ValueNetworkMixin, base
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_model=None,
|
||||
existing_inputs=None,
|
||||
):
|
||||
# First thing first, enable eager execution if necessary.
|
||||
base.enable_eager_execution_if_necessary()
|
||||
|
||||
config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config)
|
||||
validate_config(self, obs_space, action_space, config)
|
||||
|
||||
# Initialize base class.
|
||||
base.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_inputs=existing_inputs,
|
||||
existing_model=existing_model,
|
||||
)
|
||||
|
||||
ComputeGAEMixIn.__init__(self)
|
||||
ComputeAndClipGradsMixIn.__init__(self)
|
||||
KLCoeffMixin.__init__(self, config)
|
||||
ValueNetworkMixin.__init__(self, config)
|
||||
|
||||
# Create the `split` placeholder before initialize loss.
|
||||
if self.framework == "tf":
|
||||
self._loss_input_dict["split"] = tf1.placeholder(
|
||||
tf.int32,
|
||||
name="Meta-Update-Splitting",
|
||||
shape=(
|
||||
self.config["inner_adaptation_steps"] + 1,
|
||||
self.config["num_workers"],
|
||||
),
|
||||
)
|
||||
|
||||
# Note: this is a bit ugly, but loss and optimizer initialization must
|
||||
# happen after all the MixIns are initialized.
|
||||
self.maybe_initialize_optimizer_and_loss()
|
||||
|
||||
@override(base)
|
||||
def loss(
|
||||
self,
|
||||
model: Union[ModelV2, "tf.keras.Model"],
|
||||
dist_class: Type[TFActionDistribution],
|
||||
train_batch: SampleBatch,
|
||||
) -> Union[TensorType, List[TensorType]]:
|
||||
logits, state = model(train_batch)
|
||||
self.cur_lr = self.config["lr"]
|
||||
|
||||
if self.config["worker_index"]:
|
||||
self.loss_obj = WorkerLoss(
|
||||
dist_class=dist_class,
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
curr_logits=logits,
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
value_fn=model.value_function(),
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=0.0,
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_param=self.config["clip_param"],
|
||||
vf_clip_param=self.config["vf_clip_param"],
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
clip_loss=False,
|
||||
)
|
||||
else:
|
||||
self.var_list = tf1.get_collection(
|
||||
tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name
|
||||
)
|
||||
self.loss_obj = MAMLLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=self.kl_coeff,
|
||||
policy_vars=self.var_list,
|
||||
obs=train_batch[SampleBatch.CUR_OBS],
|
||||
num_tasks=self.config["num_workers"],
|
||||
split=train_batch["split"],
|
||||
config=self.config,
|
||||
inner_adaptation_steps=self.config["inner_adaptation_steps"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_param=self.config["clip_param"],
|
||||
vf_clip_param=self.config["vf_clip_param"],
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
use_gae=self.config["use_gae"],
|
||||
)
|
||||
|
||||
return self.loss_obj.loss
|
||||
|
||||
@override(base)
|
||||
def optimizer(
|
||||
self,
|
||||
) -> Union[
|
||||
"tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]
|
||||
]:
|
||||
"""
|
||||
Workers use simple SGD for inner adaptation
|
||||
Meta-Policy uses Adam optimizer for meta-update
|
||||
"""
|
||||
if not self.config["worker_index"]:
|
||||
return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
|
||||
return tf1.train.GradientDescentOptimizer(
|
||||
learning_rate=self.config["inner_lr"]
|
||||
)
|
||||
|
||||
@override(base)
|
||||
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
if self.config["worker_index"]:
|
||||
return {"worker_loss": self.loss_obj.loss}
|
||||
else:
|
||||
return {
|
||||
"cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64),
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"total_loss": self.loss_obj.loss,
|
||||
"policy_loss": self.loss_obj.mean_policy_loss,
|
||||
"vf_loss": self.loss_obj.mean_vf_loss,
|
||||
"kl": self.loss_obj.mean_kl,
|
||||
"inner_kl": self.loss_obj.mean_inner_kl,
|
||||
"entropy": self.loss_obj.mean_entropy,
|
||||
}
|
||||
|
||||
return MAMLTFPolicy
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
ValueNetworkMixin.__init__(policy, config)
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
|
||||
# Create the `split` placeholder.
|
||||
policy._loss_input_dict["split"] = tf1.placeholder(
|
||||
tf.int32,
|
||||
name="Meta-Update-Splitting",
|
||||
shape=(
|
||||
policy.config["inner_adaptation_steps"] + 1,
|
||||
policy.config["num_workers"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
MAMLTFPolicy = build_tf_policy(
|
||||
name="MAMLTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
|
||||
loss_fn=maml_loss,
|
||||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
||||
extra_action_out_fn=vf_preds_fetches,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
compute_gradients_fn=compute_and_clip_gradients,
|
||||
before_init=setup_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[KLCoeffMixin],
|
||||
)
|
||||
MAMLDynamicTFPolicy = get_maml_tf_policy(DynamicTFPolicyV2)
|
||||
MAMLEagerTFPolicy = get_maml_tf_policy(EagerTFPolicyV2)
|
||||
|
|
|
@@ -1,17 +1,20 @@
|
|||
import higher
|
||||
import logging
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import (
|
||||
compute_gae_for_sample_batch,
|
||||
Postprocessing,
|
||||
)
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
|
||||
from ray.rllib.policy.torch_mixins import ComputeGAEMixIn, ValueNetworkMixin
|
||||
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.torch_utils import apply_grad_clipping
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
@@ -145,9 +148,6 @@ class MAMLLoss(object):
|
|||
vf_loss_coeff=1.0,
|
||||
use_gae=True,
|
||||
):
|
||||
|
||||
import higher
|
||||
|
||||
self.config = config
|
||||
self.num_tasks = num_tasks
|
||||
self.inner_adaptation_steps = inner_adaptation_steps
|
||||
|
@@ -266,86 +266,6 @@ class MAMLLoss(object):
|
|||
return placeholder_list
|
||||
|
||||
|
||||
def maml_loss(policy, model, dist_class, train_batch):
|
||||
logits, state = model(train_batch)
|
||||
policy.cur_lr = policy.config["lr"]
|
||||
|
||||
if policy.config["worker_index"]:
|
||||
policy.loss_obj = WorkerLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
curr_logits=logits,
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
value_fn=model.value_function(),
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=0.0,
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
clip_loss=False,
|
||||
)
|
||||
else:
|
||||
policy.var_list = model.named_parameters()
|
||||
|
||||
# `split` may not exist yet (during test-loss call), use a dummy value.
|
||||
# Cannot use get here due to train_batch being a TrackingDict.
|
||||
if "split" in train_batch:
|
||||
split = train_batch["split"]
|
||||
else:
|
||||
split_shape = (
|
||||
policy.config["inner_adaptation_steps"],
|
||||
policy.config["num_workers"],
|
||||
)
|
||||
split_const = int(
|
||||
train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1])
|
||||
)
|
||||
split = torch.ones(split_shape, dtype=int) * split_const
|
||||
policy.loss_obj = MAMLLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=policy.kl_coeff_val,
|
||||
policy_vars=policy.var_list,
|
||||
obs=train_batch[SampleBatch.CUR_OBS],
|
||||
num_tasks=policy.config["num_workers"],
|
||||
split=split,
|
||||
config=policy.config,
|
||||
inner_adaptation_steps=policy.config["inner_adaptation_steps"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
meta_opt=policy.meta_opt,
|
||||
)
|
||||
|
||||
return policy.loss_obj.loss
|
||||
|
||||
|
||||
def maml_stats(policy, train_batch):
|
||||
if policy.config["worker_index"]:
|
||||
return {"worker_loss": policy.loss_obj.loss}
|
||||
else:
|
||||
return {
|
||||
"cur_kl_coeff": policy.kl_coeff_val,
|
||||
"cur_lr": policy.cur_lr,
|
||||
"total_loss": policy.loss_obj.loss,
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss,
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"kl_loss": policy.loss_obj.mean_kl_loss,
|
||||
"inner_kl": policy.loss_obj.mean_inner_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
}
|
||||
|
||||
|
||||
class KLCoeffMixin:
|
||||
def __init__(self, config):
|
||||
self.kl_coeff_val = (
|
||||
|
@@ -364,33 +284,141 @@ class KLCoeffMixin:
|
|||
return self.kl_coeff_val
|
||||
|
||||
|
||||
def maml_optimizer_fn(policy, config):
|
||||
"""
|
||||
Workers use simple SGD for inner adaptation
|
||||
Meta-Policy uses Adam optimizer for meta-update
|
||||
"""
|
||||
if not config["worker_index"]:
|
||||
policy.meta_opt = torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
|
||||
return policy.meta_opt
|
||||
return torch.optim.SGD(policy.model.parameters(), lr=config["inner_lr"])
|
||||
class MAMLTorchPolicy(ComputeGAEMixIn, ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2):
|
||||
"""PyTorch policy class used with MAMLTrainer."""
|
||||
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config)
|
||||
validate_config(self, observation_space, action_space, config)
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
ValueNetworkMixin.__init__(policy, config)
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
TorchPolicyV2.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
max_seq_len=config["model"]["max_seq_len"],
|
||||
)
|
||||
|
||||
ComputeGAEMixIn.__init__(self)
|
||||
KLCoeffMixin.__init__(self, config)
|
||||
ValueNetworkMixin.__init__(self, config)
|
||||
|
||||
MAMLTorchPolicy = build_policy_class(
|
||||
name="MAMLTorchPolicy",
|
||||
framework="torch",
|
||||
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
|
||||
loss_fn=maml_loss,
|
||||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
||||
extra_action_out_fn=vf_preds_fetches,
|
||||
postprocess_fn=compute_gae_for_sample_batch,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
before_init=setup_config,
|
||||
after_init=setup_mixins,
|
||||
mixins=[KLCoeffMixin],
|
||||
)
|
||||
# TODO: Don't require users to call this manually.
|
||||
self._initialize_loss_from_dummy_batch()
|
||||
|
||||
@override(TorchPolicyV2)
|
||||
def loss(
|
||||
self,
|
||||
model: ModelV2,
|
||||
dist_class: Type[TorchDistributionWrapper],
|
||||
train_batch: SampleBatch,
|
||||
) -> Union[TensorType, List[TensorType]]:
|
||||
"""Constructs the loss function.
|
||||
|
||||
Args:
|
||||
model: The Model to calculate the loss for.
|
||||
dist_class: The action distr. class.
|
||||
train_batch: The training data.
|
||||
|
||||
Returns:
|
||||
The PPO loss tensor given the input batch.
|
||||
"""
|
||||
logits, state = model(train_batch)
|
||||
self.cur_lr = self.config["lr"]
|
||||
|
||||
if self.config["worker_index"]:
|
||||
self.loss_obj = WorkerLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
curr_logits=logits,
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
value_fn=model.value_function(),
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=0.0,
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_param=self.config["clip_param"],
|
||||
vf_clip_param=self.config["vf_clip_param"],
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
clip_loss=False,
|
||||
)
|
||||
else:
|
||||
self.var_list = model.named_parameters()
|
||||
|
||||
# `split` may not exist yet (during test-loss call), use a dummy value.
|
||||
# Cannot use get here due to train_batch being a TrackingDict.
|
||||
if "split" in train_batch:
|
||||
split = train_batch["split"]
|
||||
else:
|
||||
split_shape = (
|
||||
self.config["inner_adaptation_steps"],
|
||||
self.config["num_workers"],
|
||||
)
|
||||
split_const = int(
|
||||
train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1])
|
||||
)
|
||||
split = torch.ones(split_shape, dtype=int) * split_const
|
||||
self.loss_obj = MAMLLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
|
||||
advantages=train_batch[Postprocessing.ADVANTAGES],
|
||||
actions=train_batch[SampleBatch.ACTIONS],
|
||||
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
vf_preds=train_batch[SampleBatch.VF_PREDS],
|
||||
cur_kl_coeff=self.kl_coeff_val,
|
||||
policy_vars=self.var_list,
|
||||
obs=train_batch[SampleBatch.CUR_OBS],
|
||||
num_tasks=self.config["num_workers"],
|
||||
split=split,
|
||||
config=self.config,
|
||||
inner_adaptation_steps=self.config["inner_adaptation_steps"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_param=self.config["clip_param"],
|
||||
vf_clip_param=self.config["vf_clip_param"],
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
use_gae=self.config["use_gae"],
|
||||
meta_opt=self.meta_opt,
|
||||
)
|
||||
|
||||
return self.loss_obj.loss
|
||||
|
||||
@override(TorchPolicyV2)
|
||||
def optimizer(
|
||||
self,
|
||||
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
|
||||
"""
|
||||
Workers use simple SGD for inner adaptation
|
||||
Meta-Policy uses Adam optimizer for meta-update
|
||||
"""
|
||||
if not self.config["worker_index"]:
|
||||
self.meta_opt = torch.optim.Adam(
|
||||
self.model.parameters(), lr=self.config["lr"]
|
||||
)
|
||||
return self.meta_opt
|
||||
return torch.optim.SGD(self.model.parameters(), lr=self.config["inner_lr"])
|
||||
|
||||
@override(TorchPolicyV2)
|
||||
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
if self.config["worker_index"]:
|
||||
return convert_to_numpy({"worker_loss": self.loss_obj.loss})
|
||||
else:
|
||||
return convert_to_numpy(
|
||||
{
|
||||
"cur_kl_coeff": self.kl_coeff_val,
|
||||
"cur_lr": self.cur_lr,
|
||||
"total_loss": self.loss_obj.loss,
|
||||
"policy_loss": self.loss_obj.mean_policy_loss,
|
||||
"vf_loss": self.loss_obj.mean_vf_loss,
|
||||
"kl_loss": self.loss_obj.mean_kl_loss,
|
||||
"inner_kl": self.loss_obj.mean_inner_kl,
|
||||
"entropy": self.loss_obj.mean_entropy,
|
||||
}
|
||||
)
|
||||
|
||||
def extra_grad_process(
|
||||
self, optimizer: "torch.optim.Optimizer", loss: TensorType
|
||||
) -> Dict[str, TensorType]:
|
||||
return apply_grad_clipping(self, optimizer, loss)
|
||||
|
|
|
@@ -1,13 +1,17 @@
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
    MARWILDynamicTFPolicy,
    MARWILEagerTFPolicy,
)
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy

__all__ = [
    "BCTrainer",
    "BC_DEFAULT_CONFIG",
    "DEFAULT_CONFIG",
    "MARWILTFPolicy",
    "MARWILDynamicTFPolicy",
    "MARWILEagerTFPolicy",
    "MARWILTorchPolicy",
    "MARWILTrainer",
]
@@ -1,7 +1,6 @@
from typing import Type

from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
    synchronous_parallel_sample,
@@ -123,8 +122,16 @@ class MARWILTrainer(Trainer):
            )

            return MARWILTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.marwil.marwil_tf_policy import (
                MARWILDynamicTFPolicy,
            )

            return MARWILDynamicTFPolicy
        else:
            return MARWILTFPolicy
            from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy

            return MARWILEagerTFPolicy

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
@@ -1,100 +1,69 @@
|
|||
import logging
|
||||
import gym
|
||||
from typing import Optional, Dict
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils.framework import try_import_tf, get_variable
|
||||
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, TensorType, PolicyID
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
||||
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_mixins import ComputeAndClipGradsMixIn, ValueNetworkMixin
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf, get_variable
|
||||
from ray.rllib.utils.tf_utils import explained_variance
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValueNetworkMixin:
|
||||
def __init__(
|
||||
class PostprocessAdvantages:
|
||||
"""Marwil's custom trajectory post-processing mixin."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def postprocess_trajectory(
|
||||
self,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
|
||||
episode: Optional["Episode"] = None,
|
||||
):
|
||||
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
@make_tf_callable(self.get_session())
|
||||
def value(**input_dict):
|
||||
model_out, _ = self.model(input_dict)
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
self._value = value
|
||||
|
||||
|
||||
def postprocess_advantages(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
|
||||
episode=None,
|
||||
) -> SampleBatch:
|
||||
"""Postprocesses a trajectory and returns the processed trajectory.
|
||||
|
||||
The trajectory contains only data from one episode and from one agent.
|
||||
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
|
||||
contain a truncated (at-the-end) episode, in case the
|
||||
`config.rollout_fragment_length` was reached by the sampler.
|
||||
- If `config.batch_mode=complete_episodes`, sample_batch will contain
|
||||
exactly one episode (no matter how long).
|
||||
New columns can be added to sample_batch and existing ones may be altered.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy used to generate the trajectory
|
||||
(`sample_batch`)
|
||||
sample_batch (SampleBatch): The SampleBatch to postprocess.
|
||||
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
||||
dict of AgentIDs mapping to other agents' trajectory data (from the
|
||||
same episode). NOTE: The other agents use the same policy.
|
||||
episode (Optional[Episode]): Optional multi-agent episode
|
||||
object in which the agents operated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
|
||||
# Trajectory is actually complete -> last r=0.0.
|
||||
if sample_batch[SampleBatch.DONES][-1]:
|
||||
last_r = 0.0
|
||||
# Trajectory has been truncated -> last r=VF estimate of last obs.
|
||||
else:
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
# Create an input dict according to the Model's requirements.
|
||||
index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
|
||||
input_dict = sample_batch.get_single_step_input_dict(
|
||||
policy.model.view_requirements, index=index
|
||||
sample_batch = super().postprocess_trajectory(
|
||||
sample_batch, other_agent_batches, episode
|
||||
)
|
||||
last_r = policy._value(**input_dict)
|
||||
|
||||
# Adds the "advantages" (which in the case of MARWIL are simply the
|
||||
# discounted cummulative rewards) to the SampleBatch.
|
||||
return compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
# We just want the discounted cummulative rewards, so we won't need
|
||||
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
|
||||
use_gae=False,
|
||||
use_critic=False,
|
||||
)
|
||||
# Trajectory is actually complete -> last r=0.0.
|
||||
if sample_batch[SampleBatch.DONES][-1]:
|
||||
last_r = 0.0
|
||||
# Trajectory has been truncated -> last r=VF estimate of last obs.
|
||||
else:
|
||||
# Input dict is provided to us automatically via the Model's
|
||||
# requirements. It's a single-timestep (last one in trajectory)
|
||||
# input_dict.
|
||||
# Create an input dict according to the Model's requirements.
|
||||
index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
|
||||
input_dict = sample_batch.get_single_step_input_dict(
|
||||
self.model.view_requirements, index=index
|
||||
)
|
||||
last_r = self._value(**input_dict)
|
||||
|
||||
# Adds the "advantages" (which in the case of MARWIL are simply the
|
||||
# discounted cummulative rewards) to the SampleBatch.
|
||||
return compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
self.config["gamma"],
|
||||
# We just want the discounted cummulative rewards, so we won't need
|
||||
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
|
||||
use_gae=False,
|
||||
use_critic=False,
|
||||
)
|
||||
|
||||
|
||||
class MARWILLoss:
|
||||
|
@@ -174,69 +143,100 @@ class MARWILLoss:
|
|||
self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
|
||||
|
||||
|
||||
def marwil_loss(
|
||||
policy: Policy,
|
||||
model: ModelV2,
|
||||
dist_class: ActionDistribution,
|
||||
train_batch: SampleBatch,
|
||||
) -> TensorType:
|
||||
model_out, _ = model(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
value_estimates = model.value_function()
|
||||
# We need this builder function because we want to share the same
|
||||
# custom logics between TF1 dynamic and TF2 eager policies.
|
||||
def get_marwil_tf_policy(base: type) -> type:
|
||||
"""Construct a MARWILTFPolicy inheriting either dynamic or eager base policies.
|
||||
|
||||
policy.loss = MARWILLoss(
|
||||
policy,
|
||||
value_estimates,
|
||||
action_dist,
|
||||
train_batch,
|
||||
policy.config["vf_coeff"],
|
||||
policy.config["beta"],
|
||||
)
|
||||
Args:
|
||||
base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
|
||||
|
||||
return policy.loss.total_loss
|
||||
Returns:
|
||||
A TF Policy to be used with MAMLTrainer.
|
||||
"""
|
||||
|
||||
class MARWILTFPolicy(
|
||||
ComputeAndClipGradsMixIn, ValueNetworkMixin, PostprocessAdvantages, base
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_model=None,
|
||||
existing_inputs=None,
|
||||
):
|
||||
# First thing first, enable eager execution if necessary.
|
||||
base.enable_eager_execution_if_necessary()
|
||||
|
||||
config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config)
|
||||
|
||||
# Initialize base class.
|
||||
base.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
config,
|
||||
existing_inputs=existing_inputs,
|
||||
existing_model=existing_model,
|
||||
)
|
||||
|
||||
ComputeAndClipGradsMixIn.__init__(self)
|
||||
ValueNetworkMixin.__init__(self, config)
|
||||
PostprocessAdvantages.__init__(self)
|
||||
|
||||
# Not needed for pure BC.
|
||||
if config["beta"] != 0.0:
|
||||
# Set up a tf-var for the moving avg (do this here to make it work
|
||||
# with eager mode); "c^2" in the paper.
|
||||
self._moving_average_sqd_adv_norm = get_variable(
|
||||
config["moving_average_sqd_adv_norm_start"],
|
||||
framework="tf",
|
||||
tf_name="moving_average_of_advantage_norm",
|
||||
trainable=False,
|
||||
)
|
||||
|
||||
# Note: this is a bit ugly, but loss and optimizer initialization must
|
||||
# happen after all the MixIns are initialized.
|
||||
self.maybe_initialize_optimizer_and_loss()
|
||||
|
||||
@override(base)
|
||||
def loss(
|
||||
self,
|
||||
model: Union[ModelV2, "tf.keras.Model"],
|
||||
dist_class: Type[TFActionDistribution],
|
||||
train_batch: SampleBatch,
|
||||
) -> Union[TensorType, List[TensorType]]:
|
||||
model_out, _ = model(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
value_estimates = model.value_function()
|
||||
|
||||
self.loss = MARWILLoss(
|
||||
self,
|
||||
value_estimates,
|
||||
action_dist,
|
||||
train_batch,
|
||||
self.config["vf_coeff"],
|
||||
self.config["beta"],
|
||||
)
|
||||
|
||||
return self.loss.total_loss
|
||||
|
||||
@override(base)
|
||||
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
stats = {
|
||||
"policy_loss": self.loss.p_loss,
|
||||
"total_loss": self.loss.total_loss,
|
||||
}
|
||||
if self.config["beta"] != 0.0:
|
||||
stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
|
||||
stats["vf_explained_var"] = self.loss.explained_variance
|
||||
stats["vf_loss"] = self.loss.v_loss
|
||||
|
||||
return stats
|
||||
|
||||
return MARWILTFPolicy
|
||||
|
||||
|
||||
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
stats = {
|
||||
"policy_loss": policy.loss.p_loss,
|
||||
"total_loss": policy.loss.total_loss,
|
||||
}
|
||||
if policy.config["beta"] != 0.0:
|
||||
stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm
|
||||
stats["vf_explained_var"] = policy.loss.explained_variance
|
||||
stats["vf_loss"] = policy.loss.v_loss
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def setup_mixins(
|
||||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
) -> None:
|
||||
# Setup Value branch of our NN.
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
# Not needed for pure BC.
|
||||
if policy.config["beta"] != 0.0:
|
||||
# Set up a tf-var for the moving avg (do this here to make it work
|
||||
# with eager mode); "c^2" in the paper.
|
||||
policy._moving_average_sqd_adv_norm = get_variable(
|
||||
policy.config["moving_average_sqd_adv_norm_start"],
|
||||
framework="tf",
|
||||
tf_name="moving_average_of_advantage_norm",
|
||||
trainable=False,
|
||||
)
|
||||
|
||||
|
||||
MARWILTFPolicy = build_tf_policy(
|
||||
name="MARWILTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
|
||||
loss_fn=marwil_loss,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_advantages,
|
||||
before_loss_init=setup_mixins,
|
||||
compute_gradients_fn=compute_and_clip_gradients,
|
||||
mixins=[ValueNetworkMixin],
|
||||
)
|
||||
MARWILDynamicTFPolicy = get_marwil_tf_policy(DynamicTFPolicyV2)
|
||||
MARWILEagerTFPolicy = get_marwil_tf_policy(EagerTFPolicyV2)
|
||||
|
|
|
@@ -1,122 +1,124 @@
|
|||
import gym
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages
|
||||
from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
|
||||
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, TensorType
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
def marwil_loss(
|
||||
policy: Policy,
|
||||
model: ModelV2,
|
||||
dist_class: ActionDistribution,
|
||||
train_batch: SampleBatch,
|
||||
) -> TensorType:
|
||||
model_out, _ = model(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
actions = train_batch[SampleBatch.ACTIONS]
|
||||
# log\pi_\theta(a|s)
|
||||
logprobs = action_dist.logp(actions)
|
||||
class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
|
||||
"""PyTorch policy class used with MarwilTrainer."""
|
||||
|
||||
# Advantage estimation.
|
||||
if policy.config["beta"] != 0.0:
|
||||
cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
|
||||
state_values = model.value_function()
|
||||
adv = cumulative_rewards - state_values
|
||||
adv_squared_mean = torch.mean(torch.pow(adv, 2.0))
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config)
|
||||
|
||||
explained_var = explained_variance(cumulative_rewards, state_values)
|
||||
policy.explained_variance = torch.mean(explained_var)
|
||||
|
||||
# Policy loss.
|
||||
# Update averaged advantage norm.
|
||||
rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
|
||||
policy._moving_average_sqd_adv_norm.add_(
|
||||
rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm)
|
||||
TorchPolicyV2.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
config,
|
||||
max_seq_len=config["model"]["max_seq_len"],
|
||||
)
|
||||
# Exponentially weighted advantages.
|
||||
exp_advs = torch.exp(
|
||||
policy.config["beta"]
|
||||
* (adv / (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5)))
|
||||
).detach()
|
||||
# Value loss.
|
||||
policy.v_loss = 0.5 * adv_squared_mean
|
||||
else:
|
||||
# Policy loss (simple BC loss term).
|
||||
exp_advs = 1.0
|
||||
# Value loss.
|
||||
policy.v_loss = 0.0
|
||||
|
||||
# logprob loss alone tends to push action distributions to
|
||||
# have very low entropy, resulting in worse performance for
|
||||
# unfamiliar situations.
|
||||
# A scaled logstd loss term encourages stochasticity, thus
|
||||
# alleviate the problem to some extent.
|
||||
logstd_coeff = policy.config["bc_logstd_coeff"]
|
||||
if logstd_coeff > 0.0:
|
||||
logstds = torch.mean(action_dist.log_std, dim=1)
|
||||
else:
|
||||
logstds = 0.0
|
||||
ValueNetworkMixin.__init__(self, config)
|
||||
PostprocessAdvantages.__init__(self)
|
||||
|
||||
policy.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))
|
||||
# Not needed for pure BC.
|
||||
if config["beta"] != 0.0:
|
||||
# Set up a torch-var for the squared moving avg. advantage norm.
|
||||
self._moving_average_sqd_adv_norm = torch.tensor(
|
||||
[config["moving_average_sqd_adv_norm_start"]],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False,
|
||||
).to(self.device)
|
||||
|
||||
# Combine both losses.
|
||||
policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * policy.v_loss
|
||||
# TODO: Don't require users to call this manually.
|
||||
self._initialize_loss_from_dummy_batch()
|
||||
|
||||
return policy.total_loss
|
||||
@override(TorchPolicyV2)
|
||||
def loss(
|
||||
self,
|
||||
model: ModelV2,
|
||||
dist_class: Type[TorchDistributionWrapper],
|
||||
train_batch: SampleBatch,
|
||||
) -> Union[TensorType, List[TensorType]]:
|
||||
model_out, _ = model(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
actions = train_batch[SampleBatch.ACTIONS]
|
||||
# log\pi_\theta(a|s)
|
||||
logprobs = action_dist.logp(actions)
|
||||
|
||||
# Advantage estimation.
|
||||
if self.config["beta"] != 0.0:
|
||||
cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
|
||||
state_values = model.value_function()
|
||||
adv = cumulative_rewards - state_values
|
||||
adv_squared_mean = torch.mean(torch.pow(adv, 2.0))
|
||||
|
||||
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
stats = {
|
||||
"policy_loss": policy.p_loss,
|
||||
"total_loss": policy.total_loss,
|
||||
}
|
||||
if policy.config["beta"] != 0.0:
|
||||
stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm
|
||||
stats["vf_explained_var"] = policy.explained_variance
|
||||
stats["vf_loss"] = policy.v_loss
|
||||
explained_var = explained_variance(cumulative_rewards, state_values)
|
||||
self.explained_variance = torch.mean(explained_var)
|
||||
|
||||
return stats
|
||||
# Policy loss.
|
||||
# Update averaged advantage norm.
|
||||
rate = self.config["moving_average_sqd_adv_norm_update_rate"]
|
||||
self._moving_average_sqd_adv_norm.add_(
|
||||
rate * (adv_squared_mean - self._moving_average_sqd_adv_norm)
|
||||
)
|
||||
# Exponentially weighted advantages.
|
||||
exp_advs = torch.exp(
|
||||
self.config["beta"]
|
||||
* (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5)))
|
||||
).detach()
|
||||
# Value loss.
|
||||
self.v_loss = 0.5 * adv_squared_mean
|
||||
else:
|
||||
# Policy loss (simple BC loss term).
|
||||
exp_advs = 1.0
|
||||
# Value loss.
|
||||
self.v_loss = 0.0
|
||||
|
||||
# logprob loss alone tends to push action distributions to
|
||||
# have very low entropy, resulting in worse performance for
|
||||
# unfamiliar situations.
|
||||
# A scaled logstd loss term encourages stochasticity, thus
|
||||
# alleviate the problem to some extent.
|
||||
logstd_coeff = self.config["bc_logstd_coeff"]
|
||||
if logstd_coeff > 0.0:
|
||||
logstds = torch.mean(action_dist.log_std, dim=1)
|
||||
else:
|
||||
logstds = 0.0
|
||||
|
||||
def setup_mixins(
|
||||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
) -> None:
|
||||
# Setup Value branch of our NN.
|
||||
ValueNetworkMixin.__init__(policy, config)
|
||||
self.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))
|
||||
|
||||
# Not needed for pure BC.
|
||||
if policy.config["beta"] != 0.0:
|
||||
# Set up a torch-var for the squared moving avg. advantage norm.
|
||||
policy._moving_average_sqd_adv_norm = torch.tensor(
|
||||
[policy.config["moving_average_sqd_adv_norm_start"]],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False,
|
||||
).to(policy.device)
|
||||
# Combine both losses.
|
||||
self.total_loss = self.p_loss + self.config["vf_coeff"] * self.v_loss
|
||||
|
||||
return self.total_loss
|
||||
|
||||
MARWILTorchPolicy = build_policy_class(
|
||||
name="MARWILTorchPolicy",
|
||||
framework="torch",
|
||||
loss_fn=marwil_loss,
|
||||
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_advantages,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[ValueNetworkMixin],
|
||||
)
|
||||
@override(TorchPolicyV2)
|
||||
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
stats = {
|
||||
"policy_loss": self.p_loss,
|
||||
"total_loss": self.total_loss,
|
||||
}
|
||||
if self.config["beta"] != 0.0:
|
||||
stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
|
||||
stats["vf_explained_var"] = self.explained_variance
|
||||
stats["vf_loss"] = self.v_loss
|
||||
return convert_to_numpy(stats)
|
||||
|
||||
def extra_grad_process(
|
||||
self, optimizer: "torch.optim.Optimizer", loss: TensorType
|
||||
) -> Dict[str, TensorType]:
|
||||
return apply_grad_clipping(self, optimizer, loss)
|
||||
|
|
|
@@ -5,6 +5,8 @@ import unittest

import ray
import ray.rllib.algorithms.marwil as marwil
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILEagerTFPolicy
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.offline import JsonReader
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -182,9 +184,7 @@ class TestMARWIL(unittest.TestCase):
        batch.set_get_interceptor(None)
        postprocessed_batch = policy.postprocess_trajectory(batch)
        loss_func = (
            marwil.marwil_tf_policy.marwil_loss
            if fw != "torch"
            else marwil.marwil_torch_policy.marwil_loss
            MARWILEagerTFPolicy.loss if fw != "torch" else MARWILTorchPolicy.loss
        )
        if fw != "tf":
            policy._lazy_tensor_dict(postprocessed_batch)
@@ -4,7 +4,6 @@ from typing import List, Type

import ray
from ray.rllib.agents import with_common_config
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.agents.trainer import Trainer
@@ -374,6 +373,8 @@ class MBMPOTrainer(Trainer):

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy

        return MBMPOTorchPolicy

    @staticmethod
@@ -1,135 +1,90 @@
|
|||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
|
||||
from ray.rllib.algorithms.maml.maml_torch_policy import (
|
||||
setup_mixins,
|
||||
maml_loss,
|
||||
maml_stats,
|
||||
maml_optimizer_fn,
|
||||
KLCoeffMixin,
|
||||
)
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
|
||||
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
|
||||
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_utils import apply_grad_clipping
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_spaces(
|
||||
policy: Policy,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
) -> None:
|
||||
"""Validates the observation- and action spaces used for the Policy.
|
||||
class MBMPOTorchPolicy(MAMLTorchPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
# Validate spaces.
|
||||
# Only support single Box or single Discrete spaces.
|
||||
if not isinstance(action_space, (Box, Discrete)):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} is not supported for "
|
||||
"MB-MPO. Must be [Box|Discrete].".format(action_space, self)
|
||||
)
|
||||
# If Box, make sure it's a 1D vector space.
|
||||
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} has multiple dimensions "
|
||||
"{}. ".format(action_space, self, action_space.shape)
|
||||
+ "Consider reshaping this into a single dimension Box space "
|
||||
"or using the multi-agent API."
|
||||
)
|
||||
|
||||
Args:
|
||||
policy (Policy): The policy, whose spaces are being validated.
|
||||
observation_space (gym.spaces.Space): The observation space to
|
||||
validate.
|
||||
action_space (gym.spaces.Space): The action space to validate.
|
||||
config (TrainerConfigDict): The Policy's config dict.
|
||||
config = dict(ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG, **config)
|
||||
super().__init__(observation_space, action_space, config)
|
||||
|
||||
Raises:
|
||||
UnsupportedSpaceException: If one of the spaces is not supported.
|
||||
"""
|
||||
# Only support single Box or single Discrete spaces.
|
||||
if not isinstance(action_space, (Box, Discrete)):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} is not supported for "
|
||||
"MB-MPO. Must be [Box|Discrete].".format(action_space, policy)
|
||||
)
|
||||
# If Box, make sure it's a 1D vector space.
|
||||
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} has multiple dimensions "
|
||||
"{}. ".format(action_space, policy, action_space.shape)
|
||||
+ "Consider reshaping this into a single dimension Box space "
|
||||
"or using the multi-agent API."
|
||||
    def make_model_and_action_dist(
        self,
    ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
        """Constructs the necessary ModelV2 and action dist class for the Policy.

        Returns:
            ModelV2: The ModelV2 to be used by the Policy. Note: An additional
                dynamics model (ensemble) will be created in this function and
                assigned to `self.dynamics_model`.
        """
        # Get the output distribution class for predicting rewards and next-obs.
        self.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
            self.observation_space,
            self.config,
            dist_type="deterministic",
            framework="torch",
        )

        # Build one dynamics model if we are a Worker.
        # If we are the main MAML learner, build n (num_workers) dynamics Models
        # for being able to create checkpoints for the current state of training.
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.dynamics_model = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            num_outputs=num_outputs,
            model_config=self.config["dynamics_model"],
            framework="torch",
            name="dynamics_ensemble",
        ).to(device)

        action_dist, num_outputs = ModelCatalog.get_action_dist(
            self.action_space, self.config, framework="torch"
        )
        # Create the pi-model and register it with the Policy.
        self.pi = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            num_outputs=num_outputs,
            model_config=self.config["model"],
            framework="torch",
            name="policy_model",
        )

        return self.pi, action_dist


def make_model_and_action_dist(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Constructs the necessary ModelV2 and action dist class for the Policy.

    Args:
        policy (Policy): The TorchPolicy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config (TrainerConfigDict): The MB-MPO trainer's config dict.

    Returns:
        ModelV2: The ModelV2 to be used by the Policy. Note: An additional
            dynamics model (ensemble) will be created in this function and
            assigned to `policy.dynamics_model`.
    """
    # Get the output distribution class for predicting rewards and next-obs.
    policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
        obs_space, config, dist_type="deterministic", framework="torch"
    )

    # Build one dynamics model if we are a Worker.
    # If we are the main MAML learner, build n (num_workers) dynamics Models
    # for being able to create checkpoints for the current state of training.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    policy.dynamics_model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        num_outputs=num_outputs,
        model_config=config["dynamics_model"],
        framework="torch",
        name="dynamics_ensemble",
    ).to(device)

    action_dist, num_outputs = ModelCatalog.get_action_dist(
        action_space, config, framework="torch"
    )
    # Create the pi-model and register it with the Policy.
    policy.pi = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        name="policy_model",
    )

    return policy.pi, action_dist


# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
MBMPOTorchPolicy = build_policy_class(
    name="MBMPOTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG,
    make_model_and_action_dist=make_model_and_action_dist,
    loss_fn=maml_loss,
    stats_fn=maml_stats,
    optimizer_fn=maml_optimizer_fn,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[KLCoeffMixin],
)

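For orientation, here is a minimal, hypothetical sketch of the sub-classing pattern this PR migrates MB-MPO toward: instead of assembling `MBMPOTorchPolicy` via `build_policy_class(...)` as above, model and action-distribution construction becomes an overridden method on a `TorchPolicyV2` sub-class. The class name and import paths below are illustrative assumptions, not the exact classes added by this PR.

# Hypothetical illustration only -- not the exact class added by this PR.
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyModelBuildingPolicy(TorchPolicyV2):
    def make_model_and_action_dist(self):
        # Same recipe as MBMPOTorchPolicy.make_model_and_action_dist() above,
        # minus the dynamics ensemble: ask the catalog for a dist class and
        # build a matching ModelV2.
        dist_class, num_outputs = ModelCatalog.get_action_dist(
            self.action_space, self.config, framework="torch"
        )
        model = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            num_outputs=num_outputs,
            model_config=self.config["model"],
            framework="torch",
            name="policy_model",
        )
        return model, dist_class
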
@@ -29,6 +29,8 @@ from ray.rllib.utils.annotations import (
    DeveloperAPI,
    ExperimentalAPI,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    is_overridden,
)
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.exploration.exploration import Exploration

@@ -155,6 +157,20 @@ class Policy(metaclass=ABCMeta):
        # Child classes may set this.
        self.dist_class: Optional[Type] = None

        # Initialize view requirements.
        self.init_view_requirements()

        # Whether the Model's initial state (method) has been added
        # automatically based on the given view requirements of the model.
        self._model_init_state_automatically_added = False

    @DeveloperAPI
    def init_view_requirements(self):
        """Maximal view requirements dict for `learn_on_batch()` and
        `compute_actions` calls.
        Specific policies can override this function to provide custom
        list of view requirements.
        """
        # Maximal view requirements dict for `learn_on_batch()` and
        # `compute_actions` calls.
        # View requirements will be automatically filtered out later based

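The new `init_view_requirements()` hook above is what a Policy sub-class would override to register extra view requirements. A rough, hypothetical sketch follows; the `ViewRequirement` import path, class name, and column choice are assumptions for illustration.

# Hypothetical illustration only; abstract Policy methods omitted for brevity.
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement


class PrevRewardPolicy(Policy):
    def init_view_requirements(self):
        # Keep the default (maximal) view requirements ...
        super().init_view_requirements()
        # ... and additionally request the previous reward at act time.
        self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
            data_col=SampleBatch.REWARDS, shift=-1, used_for_compute_actions=True
        )
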
@@ -167,9 +183,6 @@ class Policy(metaclass=ABCMeta):
        for k, v in view_reqs.items():
            if k not in self.view_requirements:
                self.view_requirements[k] = v
        # Whether the Model's initial state (method) has been added
        # automatically based on the given view requirements of the model.
        self._model_init_state_automatically_added = False

    @DeveloperAPI
    def compute_single_action(

@@ -413,6 +426,7 @@ class Policy(metaclass=ABCMeta):
        raise NotImplementedError

    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,

@@ -739,9 +753,7 @@ class Policy(metaclass=ABCMeta):
        # steps).
        # Make sure, we keep global_timestep as a Tensor for tf-eager
        # (leads to memory leaks if not doing so).
        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy

        if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy):
        if self.framework in ["tfe", "tf2"]:
            self.global_timestep.assign(global_vars["timestep"])
        else:
            self.global_timestep = global_vars["timestep"]

@@ -952,11 +964,19 @@ class Policy(metaclass=ABCMeta):
            train_batch[SampleBatch.SEQ_LENS] = seq_lens
        train_batch.count = self._dummy_batch.count
        # Call the loss function, if it exists.
        # TODO(jungong) : clean up after all agents get migrated.
        # We should simply do self.loss(...) here.
        if self._loss is not None:
            self._loss(self, self.model, self.dist_class, train_batch)
        elif is_overridden(self.loss):
            self.loss(self.model, self.dist_class, train_batch)
        # Call the stats fn, if given.
        # TODO(jungong) : clean up after all agents get migrated.
        # We should simply do self.stats_fn(train_batch) here.
        if stats_fn is not None:
            stats_fn(self, train_batch)
        if hasattr(self, "stats_fn"):
            self.stats_fn(train_batch)

        # Re-enable tracing.
        self._no_tracing = False

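The `elif is_overridden(self.loss)` branch above is the code path taken by the newly sub-classed policies, while the function-based `self._loss` branch remains for not-yet-migrated agents. A hypothetical sketch of such an override; the class name and the toy loss are illustrative only and not part of this PR.

# Hypothetical illustration only.
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class ImitationTorchPolicy(TorchPolicyV2):
    def loss(self, model, dist_class, train_batch):
        # Toy behavioral-cloning-style loss: maximize the log-likelihood of
        # the actions stored in the batch.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        return -action_dist.logp(train_batch[SampleBatch.ACTIONS]).mean()
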
@@ -974,7 +994,7 @@ class Policy(metaclass=ABCMeta):
                self.view_requirements[key] = ViewRequirement(
                    used_for_compute_actions=False
                )
        if self._loss:
        if self._loss or is_overridden(self.loss):
            # Tag those only needed for post-processing (with some
            # exceptions).
            for key in self._dummy_batch.accessed_keys:

@@ -150,7 +150,7 @@ class TFPolicy(Policy):
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        worker_idx = config.get("worker_index", 0)
        if not config["_fake_gpus"] and ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:

@@ -237,7 +237,7 @@ class TFPolicy(Policy):
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class

        self._cached_extra_action_out = None
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens

@@ -784,6 +784,16 @@ class TFPolicy(Policy):

    @DeveloperAPI
    def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
        # Cache graph fetches for action computation for better
        # performance.
        # This function is called every time the static graph is run
        # to compute actions.
        if not self._cached_extra_action_out:
            self._cached_extra_action_out = self.extra_action_out_fn()
        return self._cached_extra_action_out

    @DeveloperAPI
    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        """Extra values to fetch and return from compute_actions().

        By default we return action probability/log-likelihood info

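The new `_cached_extra_action_out` / `extra_action_out_fn()` pair above builds the extra fetch dict once and reuses it on every static-graph run. A hypothetical sketch of a sub-class adding its own fetch; it assumes the model exposes `value_function()`, which standard RLlib models do, and the class name is made up.

# Hypothetical illustration only.
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy


class VFFetchingTFPolicy(TFPolicy):
    def extra_action_out_fn(self):
        # Start from the default fetches (action prob/log-likelihood) ...
        fetches = super().extra_action_out_fn()
        # ... and also fetch the value-function output tensor. Because
        # extra_compute_action_fetches() caches the result, this runs once.
        fetches[SampleBatch.VF_PREDS] = self.model.value_function()
        return fetches
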
@@ -230,7 +230,7 @@ class TorchPolicyV2(Policy):
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for Proximal Policy Objective.
        """Constructs the loss function.

        Args:
            model: The Model to calculate the loss for.

@@ -396,20 +396,6 @@ class TorchPolicyV2(Policy):
        """
        return {}

    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Return dict of extra grad info.

        Args:
            train_batch: The training batch to produce extra grad info for.

        Returns:
            The info dict carrying grad info per str key.
        """
        return {}

    @override(Policy)
    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended

@@ -418,8 +404,28 @@ class TorchPolicyV2(Policy):
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ):
        """Additional custom postprocessing of SampleBatch."""
    ) -> SampleBatch:
        """Postprocesses a trajectory and returns the processed trajectory.

        The trajectory contains only data from one episode and from one agent.
        - If `config.batch_mode=truncate_episodes` (default), sample_batch may
          contain a truncated (at-the-end) episode, in case the
          `config.rollout_fragment_length` was reached by the sampler.
        - If `config.batch_mode=complete_episodes`, sample_batch will contain
          exactly one episode (no matter how long).
        New columns can be added to sample_batch and existing ones may be altered.

        Args:
            sample_batch (SampleBatch): The SampleBatch to postprocess.
            other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
                dict of AgentIDs mapping to other agents' trajectory data (from the
                same episode). NOTE: The other agents use the same policy.
            episode (Optional[Episode]): Optional multi-agent episode
                object in which the agents operated.

        Returns:
            SampleBatch: The postprocessed, modified SampleBatch (or a new one).
        """
        return sample_batch

    @DeveloperAPI

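Given the `@OverrideToImplementCustomLogic_CallToSuperRecommended` decorator, an override of `postprocess_trajectory()` is expected to call `super()` first and then edit the batch. A hypothetical sketch (the added column name is made up and no discounting is applied):

# Hypothetical illustration only.
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class RewardToGoTorchPolicy(TorchPolicyV2):
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        sample_batch = super().postprocess_trajectory(
            sample_batch, other_agent_batches, episode
        )
        # Add an undiscounted reward-to-go column for this (single-agent,
        # single-episode, possibly truncated) trajectory.
        rewards = sample_batch[SampleBatch.REWARDS]
        sample_batch["rewards_to_go"] = np.cumsum(rewards[::-1])[::-1].copy()
        return sample_batch
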
@@ -775,7 +781,7 @@ class TorchPolicyV2(Policy):
        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.extra_grad_info(batch),
                    LEARNER_STATS_KEY: self.stats_fn(batch),
                    "model": model.metrics(),
                }
            )

@@ -810,7 +816,7 @@ class TorchPolicyV2(Policy):
        all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))
        grad_info.update(self.stats_fn(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()
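The two `extra_grad_info(...)` to `stats_fn(...)` replacements above mean that, for sub-classed torch policies, per-tower learner stats now come from the overridable `stats_fn()`. A hypothetical sketch of such an override; the class name and reported keys are assumptions for illustration.

# Hypothetical illustration only.
import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class VerboseStatsTorchPolicy(TorchPolicyV2):
    def stats_fn(self, train_batch):
        # Whatever is returned here is merged into grad_info and reported
        # per tower under LEARNER_STATS_KEY.
        return {
            "mean_batch_reward": torch.mean(train_batch[SampleBatch.REWARDS]),
            "batch_size": train_batch.count,
        }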