mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] PPO, APPO, and DD-PPO code cleanup. (#10420)
This commit is contained in:
parent
f10a5a40b0
commit
ef18893fb5
21 changed files with 1159 additions and 947 deletions
|
@ -372,13 +372,14 @@ def from_importance_weights(log_rhos,
|
|||
return delta_t + discount_t * c_t * acc
|
||||
|
||||
initial_values = tf.zeros_like(bootstrap_value)
|
||||
vs_minus_v_xs = tf.scan(
|
||||
fn=scanfunc,
|
||||
elems=sequences,
|
||||
initializer=initial_values,
|
||||
parallel_iterations=1,
|
||||
back_prop=False,
|
||||
name="scan")
|
||||
vs_minus_v_xs = tf.nest.map_structure(
|
||||
tf.stop_gradient,
|
||||
tf.scan(
|
||||
fn=scanfunc,
|
||||
elems=sequences,
|
||||
initializer=initial_values,
|
||||
parallel_iterations=1,
|
||||
name="scan"))
|
||||
# Reverse the results back to original order.
|
||||
vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name="vs_minus_v_xs")
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
vf_preds_fetches, clip_gradients, setup_config, ValueNetworkMixin
|
||||
vf_preds_fetches, compute_and_clip_gradients, setup_config, \
|
||||
ValueNetworkMixin
|
||||
from ray.rllib.utils.framework import get_activation_fn
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
@ -421,7 +422,7 @@ MAMLTFPolicy = build_tf_policy(
|
|||
optimizer_fn=maml_optimizer_fn,
|
||||
extra_action_fetches_fn=vf_preds_fetches,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
gradients_fn=clip_gradients,
|
||||
gradients_fn=compute_and_clip_gradients,
|
||||
before_init=setup_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[KLCoeffMixin])
|
||||
|
|
|
@ -46,7 +46,7 @@ def pg_tf_loss(
|
|||
train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32))
|
||||
|
||||
|
||||
# Build a child class of `TFPolicy`, given the extra options:
|
||||
# Build a child class of `DynamicTFPolicy`, given the extra options:
|
||||
# - trajectory post-processing function (to calculate advantages)
|
||||
# - PG loss function
|
||||
PGTFPolicy = build_tf_policy(
|
||||
|
|
23
rllib/agents/ppo/README.md
Normal file
23
rllib/agents/ppo/README.md
Normal file
|
@ -0,0 +1,23 @@
|
|||
Proximal Policy Optimization (PPO)
|
||||
==================================
|
||||
|
||||
Implementations of:
|
||||
|
||||
1) Proximal Policy Optimization (PPO).
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#ppo)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py)**
|
||||
|
||||
2) Asynchronous Proximal Policy Optimization (APPO).
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#appo)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo.py)**
|
||||
|
||||
3) Decentralized Distributed Proximal Policy Optimization (DDPPO)
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#decentralized-distributed-proximal-policy-optimization-dd-ppo)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ddppo.py)**
|
||||
|
|
@ -1,12 +1,31 @@
|
|||
"""
|
||||
Asynchronous Proximal Policy Optimization (APPO)
|
||||
================================================
|
||||
|
||||
This file defines the distributed Trainer class for the asynchronous version
|
||||
of proximal policy optimization (APPO).
|
||||
See `appo_[tf|torch]_policy.py` for the definition of the policy loss.
|
||||
|
||||
Detailed documentation:
|
||||
https://docs.ray.io/en/latest/rllib-algorithms.html#appo
|
||||
"""
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.impala.impala import validate_config
|
||||
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo import UpdateKL
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, _get_shared_metrics
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the `IMPALATrainer` config in
|
||||
# rllib/agents/impala/impala.py.
|
||||
DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
|
||||
impala.DEFAULT_CONFIG, # See keys in impala.py, which are also supported.
|
||||
{
|
||||
|
@ -60,15 +79,11 @@ DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
|
|||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def initialize_target(trainer):
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
|
||||
|
||||
class UpdateTargetAndKL:
|
||||
def __init__(self, workers, config):
|
||||
self.workers = workers
|
||||
|
@ -92,25 +107,47 @@ class UpdateTargetAndKL:
|
|||
self.update_kl(fetches)
|
||||
|
||||
|
||||
def add_target_callback(config):
|
||||
def add_target_callback(config: TrainerConfigDict):
|
||||
"""Add the update target and kl hook.
|
||||
|
||||
This hook is called explicitly after each learner step in the execution
|
||||
setup for IMPALA.
|
||||
|
||||
Args:
|
||||
config (TrainerConfigDict): The APPO config dict.
|
||||
"""
|
||||
|
||||
config["after_train_step"] = UpdateTargetAndKL
|
||||
return validate_config(config)
|
||||
validate_config(config)
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config.get("framework") == "torch":
|
||||
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
|
||||
"""Policy class picker function. Class is chosen based on DL-framework.
|
||||
|
||||
Args:
|
||||
config (TrainerConfigDict): The trainer's configuration dict.
|
||||
|
||||
Returns:
|
||||
Optional[Type[Policy]]: The Policy class to use with PPOTrainer.
|
||||
If None, use `default_policy` provided in build_trainer().
|
||||
"""
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.ppo.appo_torch_policy import AsyncPPOTorchPolicy
|
||||
return AsyncPPOTorchPolicy
|
||||
else:
|
||||
return AsyncPPOTFPolicy
|
||||
|
||||
|
||||
def initialize_target(trainer: Trainer) -> None:
|
||||
"""Updates target network on startup by synching it with the policy net.
|
||||
|
||||
Args:
|
||||
trainer (Trainer): The Trainer object.
|
||||
"""
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
|
||||
|
||||
# Build a child class of `Trainer`, based on ImpalaTrainer's setup.
|
||||
# Note: The generated class is NOT a sub-class of ImpalaTrainer, but directly
|
||||
# of the `Trainer` class.
|
||||
APPOTrainer = impala.ImpalaTrainer.with_updates(
|
||||
name="APPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
|
|
|
@ -1,25 +1,34 @@
|
|||
"""Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
||||
"""
|
||||
TensorFlow policy class used for APPO.
|
||||
|
||||
Keep in sync with changes to VTraceTFPolicy."""
|
||||
Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
||||
Keep in sync with changes to VTraceTFPolicy.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
import gym
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from ray.rllib.agents.impala import vtrace_tf as vtrace
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
|
||||
clip_gradients, choose_optimizer
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
from ray.rllib.utils.typing import AgentID, TensorType, TrainerConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
@ -29,179 +38,26 @@ TARGET_POLICY_SCOPE = "target_func"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PPOSurrogateLoss:
|
||||
"""Loss used when V-trace is disabled.
|
||||
def make_appo_model(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> ModelV2:
|
||||
"""Builds model and target model for APPO.
|
||||
|
||||
Arguments:
|
||||
prev_actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
action_kl: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
advantages: A float32 tensor of shape [T, B].
|
||||
value_targets: A float32 tensor of shape [T, B].
|
||||
vf_loss_coeff (float): Coefficient of the value function loss.
|
||||
entropy_coeff (float): Coefficient of the entropy regularizer.
|
||||
clip_param (float): Clip parameter.
|
||||
cur_kl_coeff (float): Coefficient for KL loss.
|
||||
use_kl_loss (bool): If true, use KL loss.
|
||||
Args:
|
||||
policy (Policy): The Policy, which will use the model for optimization.
|
||||
obs_space (gym.spaces.Space): The policy's observation space.
|
||||
action_space (gym.spaces.Space): The policy's action space.
|
||||
config (TrainerConfigDict):
|
||||
|
||||
Returns:
|
||||
ModelV2: The Model for the Policy to use.
|
||||
Note: The target model will not be returned, just assigned to
|
||||
`policy.target_model`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
prev_actions_logp,
|
||||
actions_logp,
|
||||
action_kl,
|
||||
actions_entropy,
|
||||
values,
|
||||
valid_mask,
|
||||
advantages,
|
||||
value_targets,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
clip_param=0.3,
|
||||
cur_kl_coeff=None,
|
||||
use_kl_loss=False):
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
|
||||
|
||||
logp_ratio = tf.math.exp(actions_logp - prev_actions_logp)
|
||||
|
||||
surrogate_loss = tf.minimum(
|
||||
advantages * logp_ratio,
|
||||
advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
|
||||
1 + clip_param))
|
||||
|
||||
self.mean_kl = reduce_mean_valid(action_kl)
|
||||
self.pi_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The baseline loss
|
||||
delta = values - value_targets
|
||||
self.value_targets = value_targets
|
||||
self.vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
|
||||
|
||||
# The entropy loss
|
||||
self.entropy = reduce_mean_valid(actions_entropy)
|
||||
|
||||
# The summed weighted loss
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
# Optional additional KL Loss
|
||||
if use_kl_loss:
|
||||
self.total_loss += cur_kl_coeff * self.mean_kl
|
||||
|
||||
|
||||
class VTraceSurrogateLoss:
|
||||
def __init__(self,
|
||||
actions,
|
||||
prev_actions_logp,
|
||||
actions_logp,
|
||||
old_policy_actions_logp,
|
||||
action_kl,
|
||||
actions_entropy,
|
||||
dones,
|
||||
behaviour_logits,
|
||||
old_policy_behaviour_logits,
|
||||
target_logits,
|
||||
discount,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
model,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
clip_param=0.3,
|
||||
cur_kl_coeff=None,
|
||||
use_kl_loss=False):
|
||||
"""APPO Loss, with IS modifications and V-trace for Advantage Estimation
|
||||
|
||||
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
|
||||
batch_size. The reason we need to know `B` is for V-trace to properly
|
||||
handle episode cut boundaries.
|
||||
|
||||
Arguments:
|
||||
actions: An int|float32 tensor of shape [T, B, logit_dim].
|
||||
prev_actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
old_policy_actions_logp: A float32 tensor of shape [T, B].
|
||||
action_kl: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
dones: A bool tensor of shape [T, B].
|
||||
behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
|
||||
old_policy_behaviour_logits: A float32 tensor of shape
|
||||
[T, B, logit_dim].
|
||||
target_logits: A float32 tensor of shape [T, B, logit_dim].
|
||||
discount: A float32 scalar.
|
||||
rewards: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
dist_class: action distribution class for logits.
|
||||
model: backing ModelV2 instance
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
vf_loss_coeff (float): Coefficient of the value function loss.
|
||||
entropy_coeff (float): Coefficient of the entropy regularizer.
|
||||
clip_param (float): Clip parameter.
|
||||
cur_kl_coeff (float): Coefficient for KL loss.
|
||||
use_kl_loss (bool): If true, use KL loss.
|
||||
"""
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
|
||||
|
||||
# Compute vtrace on the CPU for better perf.
|
||||
with tf.device("/cpu:0"):
|
||||
self.vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_policy_logits=behaviour_logits,
|
||||
target_policy_logits=old_policy_behaviour_logits,
|
||||
actions=tf.unstack(actions, axis=2),
|
||||
discounts=tf.cast(~dones, tf.float32) * discount,
|
||||
rewards=rewards,
|
||||
values=values,
|
||||
bootstrap_value=bootstrap_value,
|
||||
dist_class=dist_class,
|
||||
model=model,
|
||||
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
|
||||
tf.float32))
|
||||
|
||||
self.is_ratio = tf.clip_by_value(
|
||||
tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
|
||||
logp_ratio = self.is_ratio * tf.exp(actions_logp - prev_actions_logp)
|
||||
|
||||
advantages = self.vtrace_returns.pg_advantages
|
||||
surrogate_loss = tf.minimum(
|
||||
advantages * logp_ratio,
|
||||
advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
|
||||
1 + clip_param))
|
||||
|
||||
self.mean_kl = reduce_mean_valid(action_kl)
|
||||
self.pi_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The baseline loss
|
||||
delta = values - self.vtrace_returns.vs
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
self.vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
|
||||
|
||||
# The entropy loss
|
||||
self.entropy = reduce_mean_valid(actions_entropy)
|
||||
|
||||
# The summed weighted loss
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
# Optional additional KL Loss
|
||||
if use_kl_loss:
|
||||
self.total_loss += cur_kl_coeff * self.mean_kl
|
||||
|
||||
|
||||
def build_appo_model(policy, obs_space, action_space, config):
|
||||
# Get the num_outputs for the following model construction calls.
|
||||
_, logit_dim = ModelCatalog.get_action_dist(action_space, config["model"])
|
||||
|
||||
# Construct the (main) model.
|
||||
policy.model = ModelCatalog.get_model_v2(
|
||||
obs_space,
|
||||
action_space,
|
||||
|
@ -211,6 +67,7 @@ def build_appo_model(policy, obs_space, action_space, config):
|
|||
framework="torch" if config["framework"] == "torch" else "tf")
|
||||
policy.model_variables = policy.model.variables()
|
||||
|
||||
# Construct the target model.
|
||||
policy.target_model = ModelCatalog.get_model_v2(
|
||||
obs_space,
|
||||
action_space,
|
||||
|
@ -220,10 +77,27 @@ def build_appo_model(policy, obs_space, action_space, config):
|
|||
framework="torch" if config["framework"] == "torch" else "tf")
|
||||
policy.target_model_variables = policy.target_model.variables()
|
||||
|
||||
# Return only the model (not the target model).
|
||||
return policy.model
|
||||
|
||||
|
||||
def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
||||
def appo_surrogate_loss(
|
||||
policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
|
||||
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
|
||||
"""Constructs the loss for APPO.
|
||||
|
||||
With IS modifications and V-trace for Advantage Estimation.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to calculate the loss for.
|
||||
model (ModelV2): The Model to calculate the loss for.
|
||||
dist_class (Type[ActionDistribution]: The action distr. class.
|
||||
train_batch (SampleBatch): The training data.
|
||||
|
||||
Returns:
|
||||
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
||||
of loss tensors.
|
||||
"""
|
||||
model_out, _ = model.from_batch(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
|
||||
|
@ -238,6 +112,7 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
|
||||
# TODO: (sven) deprecate this when trajectory view API gets activated.
|
||||
def make_time_major(*args, **kw):
|
||||
return _make_time_major(policy, train_batch.get("seq_lens"), *args,
|
||||
**kw)
|
||||
|
@ -248,16 +123,9 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
|
||||
|
||||
target_model_out, _ = policy.target_model.from_batch(train_batch)
|
||||
old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
|
||||
|
||||
unpacked_behaviour_logits = tf.split(
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
unpacked_old_policy_behaviour_logits = tf.split(
|
||||
old_policy_behaviour_logits, output_hidden_shape, axis=1)
|
||||
unpacked_outputs = tf.split(model_out, output_hidden_shape, axis=1)
|
||||
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
|
||||
prev_action_dist = dist_class(behaviour_logits, policy.model)
|
||||
values = policy.model.value_function()
|
||||
values_time_major = make_time_major(values)
|
||||
|
||||
policy.model_vars = policy.model.variables()
|
||||
policy.target_model_vars = policy.target_model.variables()
|
||||
|
@ -266,80 +134,151 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - 1
|
||||
mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
mask = make_time_major(mask, drop_last=policy.config["vtrace"])
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(tf.boolean_mask(t, mask))
|
||||
|
||||
else:
|
||||
mask = tf.ones_like(rewards)
|
||||
reduce_mean_valid = tf.reduce_mean
|
||||
|
||||
if policy.config["vtrace"]:
|
||||
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
|
||||
|
||||
# Prepare actions for loss
|
||||
# Prepare actions for loss.
|
||||
loss_actions = actions if is_multidiscrete else tf.expand_dims(
|
||||
actions, axis=1)
|
||||
|
||||
old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
|
||||
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
|
||||
|
||||
# Prepare KL for Loss
|
||||
mean_kl = make_time_major(
|
||||
old_policy_action_dist.multi_kl(action_dist), drop_last=True)
|
||||
|
||||
policy.loss = VTraceSurrogateLoss(
|
||||
actions=make_time_major(loss_actions, drop_last=True),
|
||||
prev_actions_logp=make_time_major(
|
||||
prev_action_dist.logp(actions), drop_last=True),
|
||||
actions_logp=make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
old_policy_actions_logp=make_time_major(
|
||||
old_policy_action_dist.logp(actions), drop_last=True),
|
||||
action_kl=tf.reduce_mean(mean_kl, axis=0)
|
||||
if is_multidiscrete else mean_kl,
|
||||
actions_entropy=make_time_major(
|
||||
action_dist.multi_entropy(), drop_last=True),
|
||||
dones=make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
old_policy_behaviour_logits=make_time_major(
|
||||
unpacked_old_policy_behaviour_logits, drop_last=True),
|
||||
target_logits=make_time_major(unpacked_outputs, drop_last=True),
|
||||
discount=policy.config["gamma"],
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=Categorical if is_multidiscrete else dist_class,
|
||||
model=policy.model,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=policy.config[
|
||||
"vtrace_clip_pg_rho_threshold"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
cur_kl_coeff=policy.kl_coeff,
|
||||
use_kl_loss=policy.config["use_kl_loss"])
|
||||
unpacked_behaviour_logits = tf.split(
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
unpacked_old_policy_behaviour_logits = tf.split(
|
||||
old_policy_behaviour_logits, output_hidden_shape, axis=1)
|
||||
|
||||
# Compute vtrace on the CPU for better perf.
|
||||
with tf.device("/cpu:0"):
|
||||
vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_policy_logits=make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
target_policy_logits=make_time_major(
|
||||
unpacked_old_policy_behaviour_logits, drop_last=True),
|
||||
actions=tf.unstack(
|
||||
make_time_major(loss_actions, drop_last=True), axis=2),
|
||||
discounts=tf.cast(~make_time_major(dones, drop_last=True),
|
||||
tf.float32) * policy.config["gamma"],
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=values_time_major[:-1], # drop-last=True
|
||||
bootstrap_value=values_time_major[-1],
|
||||
dist_class=Categorical if is_multidiscrete else dist_class,
|
||||
model=model,
|
||||
clip_rho_threshold=tf.cast(
|
||||
policy.config["vtrace_clip_rho_threshold"], tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(
|
||||
policy.config["vtrace_clip_pg_rho_threshold"], tf.float32),
|
||||
)
|
||||
|
||||
actions_logp = make_time_major(
|
||||
action_dist.logp(actions), drop_last=True)
|
||||
prev_actions_logp = make_time_major(
|
||||
prev_action_dist.logp(actions), drop_last=True)
|
||||
old_policy_actions_logp = make_time_major(
|
||||
old_policy_action_dist.logp(actions), drop_last=True)
|
||||
|
||||
is_ratio = tf.clip_by_value(
|
||||
tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
|
||||
logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
|
||||
policy._is_ratio = is_ratio
|
||||
|
||||
advantages = vtrace_returns.pg_advantages
|
||||
surrogate_loss = tf.minimum(
|
||||
advantages * logp_ratio,
|
||||
advantages *
|
||||
tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
|
||||
1 + policy.config["clip_param"]))
|
||||
|
||||
action_kl = tf.reduce_mean(mean_kl, axis=0) \
|
||||
if is_multidiscrete else mean_kl
|
||||
mean_kl = reduce_mean_valid(action_kl)
|
||||
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The value function loss.
|
||||
delta = values_time_major[:-1] - vtrace_returns.vs
|
||||
value_targets = vtrace_returns.vs
|
||||
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
|
||||
|
||||
# The entropy loss.
|
||||
actions_entropy = make_time_major(
|
||||
action_dist.multi_entropy(), drop_last=True)
|
||||
mean_entropy = reduce_mean_valid(actions_entropy)
|
||||
|
||||
else:
|
||||
logger.debug("Using PPO surrogate loss (vtrace=False)")
|
||||
|
||||
# Prepare KL for Loss
|
||||
mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))
|
||||
|
||||
policy.loss = PPOSurrogateLoss(
|
||||
prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
|
||||
actions_logp=make_time_major(action_dist.logp(actions)),
|
||||
action_kl=tf.reduce_mean(mean_kl, axis=0)
|
||||
if is_multidiscrete else mean_kl,
|
||||
actions_entropy=make_time_major(action_dist.multi_entropy()),
|
||||
values=make_time_major(values),
|
||||
valid_mask=make_time_major(mask),
|
||||
advantages=make_time_major(train_batch[Postprocessing.ADVANTAGES]),
|
||||
value_targets=make_time_major(
|
||||
train_batch[Postprocessing.VALUE_TARGETS]),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
cur_kl_coeff=policy.kl_coeff,
|
||||
use_kl_loss=policy.config["use_kl_loss"])
|
||||
logp_ratio = tf.math.exp(
|
||||
make_time_major(action_dist.logp(actions)) -
|
||||
make_time_major(prev_action_dist.logp(actions)))
|
||||
|
||||
return policy.loss.total_loss
|
||||
advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
|
||||
surrogate_loss = tf.minimum(
|
||||
advantages * logp_ratio,
|
||||
advantages *
|
||||
tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
|
||||
1 + policy.config["clip_param"]))
|
||||
|
||||
action_kl = tf.reduce_mean(mean_kl, axis=0) \
|
||||
if is_multidiscrete else mean_kl
|
||||
mean_kl = reduce_mean_valid(action_kl)
|
||||
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The value function loss.
|
||||
value_targets = make_time_major(
|
||||
train_batch[Postprocessing.VALUE_TARGETS])
|
||||
delta = values_time_major - value_targets
|
||||
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
|
||||
|
||||
# The entropy loss.
|
||||
mean_entropy = reduce_mean_valid(
|
||||
make_time_major(action_dist.multi_entropy()))
|
||||
|
||||
# The summed weighted loss
|
||||
total_loss = mean_policy_loss + \
|
||||
mean_vf_loss * policy.config["vf_loss_coeff"] - \
|
||||
mean_entropy * policy.config["entropy_coeff"]
|
||||
|
||||
# Optional additional KL Loss
|
||||
if policy.config["use_kl_loss"]:
|
||||
total_loss += policy.kl_coeff * mean_kl
|
||||
|
||||
policy._total_loss = total_loss
|
||||
policy._mean_policy_loss = mean_policy_loss
|
||||
policy._mean_kl = mean_kl
|
||||
policy._mean_vf_loss = mean_vf_loss
|
||||
policy._mean_entropy = mean_entropy
|
||||
policy._value_targets = value_targets
|
||||
|
||||
# Store stats in policy for stats_fn.
|
||||
return total_loss
|
||||
|
||||
|
||||
def stats(policy, train_batch):
|
||||
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
"""Stats function for APPO. Returns a dict with important loss stats.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to generate stats for.
|
||||
train_batch (SampleBatch): The SampleBatch (already) used for training.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: The stats dict.
|
||||
"""
|
||||
values_batched = _make_time_major(
|
||||
policy,
|
||||
train_batch.get("seq_lens"),
|
||||
|
@ -348,31 +287,55 @@ def stats(policy, train_batch):
|
|||
|
||||
stats_dict = {
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"policy_loss": policy.loss.pi_loss,
|
||||
"entropy": policy.loss.entropy,
|
||||
"policy_loss": policy._mean_policy_loss,
|
||||
"entropy": policy._mean_entropy,
|
||||
"var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()),
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_loss": policy._mean_vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(policy.loss.value_targets, [-1]),
|
||||
tf.reshape(policy._value_targets, [-1]),
|
||||
tf.reshape(values_batched, [-1])),
|
||||
}
|
||||
|
||||
if policy.config["vtrace"]:
|
||||
is_stat_mean, is_stat_var = tf.nn.moments(policy.loss.is_ratio, [0, 1])
|
||||
stats_dict.update({"mean_IS": is_stat_mean})
|
||||
stats_dict.update({"var_IS": is_stat_var})
|
||||
is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
|
||||
stats_dict["mean_IS"] = is_stat_mean
|
||||
stats_dict["var_IS"] = is_stat_var
|
||||
|
||||
if policy.config["use_kl_loss"]:
|
||||
stats_dict.update({"kl": policy.loss.mean_kl})
|
||||
stats_dict.update({"KL_Coeff": policy.kl_coeff})
|
||||
stats_dict["kl"] = policy._mean_kl
|
||||
stats_dict["KL_Coeff"] = policy.kl_coeff
|
||||
|
||||
return stats_dict
|
||||
|
||||
|
||||
def postprocess_trajectory(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
def postprocess_trajectory(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
|
||||
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
|
||||
"""Postprocesses a trajectory and returns the processed trajectory.
|
||||
|
||||
The trajectory contains only data from one episode and from one agent.
|
||||
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
|
||||
contain a truncated (at-the-end) episode, in case the
|
||||
`config.rollout_fragment_length` was reached by the sampler.
|
||||
- If `config.batch_mode=complete_episodes`, sample_batch will contain
|
||||
exactly one episode (no matter how long).
|
||||
New columns can be added to sample_batch and existing ones may be altered.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy used to generate the trajectory
|
||||
(`sample_batch`)
|
||||
sample_batch (SampleBatch): The SampleBatch to postprocess.
|
||||
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
||||
dict of AgentIDs mapping to other agents' trajectory data (from the
|
||||
same episode). NOTE: The other agents use the same policy.
|
||||
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
|
||||
object in which the agents operated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
if not policy.config["vtrace"]:
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
|
@ -394,7 +357,10 @@ def postprocess_trajectory(policy,
|
|||
use_critic=policy.config["use_critic"])
|
||||
else:
|
||||
batch = sample_batch
|
||||
# TODO: (sven) remove this del once we have trajectory view API fully in
|
||||
# place.
|
||||
del batch.data["new_obs"] # not used, so save some bandwidth
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
|
@ -406,13 +372,14 @@ def add_values(policy):
|
|||
|
||||
|
||||
class TargetNetworkMixin:
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
"""Target Network is updated by the master learner every
|
||||
trainer.update_target_frequency steps. All worker batches
|
||||
are importance sampled w.r. to the target network to ensure
|
||||
a more stable pi_old in PPO.
|
||||
"""
|
||||
"""Target NN is updated by master learner via the `update_target` method.
|
||||
|
||||
Updates happen every `trainer.update_target_frequency` steps. All worker
|
||||
batches are importance sampled wrt the target network to ensure a more
|
||||
stable pi_old in PPO.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
@make_tf_callable(self.get_session())
|
||||
def do_update():
|
||||
assign_ops = []
|
||||
|
@ -429,20 +396,42 @@ class TargetNetworkMixin:
|
|||
return self.model_vars + self.target_model_vars
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Call all mixin classes' constructors before APPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
|
||||
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Call all mixin classes' constructors after APPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
||||
# Build a child class of `DynamicTFPolicy`, given the custom functions defined
|
||||
# above.
|
||||
AsyncPPOTFPolicy = build_tf_policy(
|
||||
name="AsyncPPOTFPolicy",
|
||||
make_model=build_appo_model,
|
||||
loss_fn=build_appo_surrogate_loss,
|
||||
make_model=make_appo_model,
|
||||
loss_fn=appo_surrogate_loss,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
optimizer_fn=choose_optimizer,
|
||||
|
|
|
@ -1,220 +1,58 @@
|
|||
"""Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
||||
"""
|
||||
PyTorch policy class used for APPO.
|
||||
|
||||
Keep in sync with changes to VTraceTFPolicy."""
|
||||
Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
||||
Keep in sync with changes to VTraceTFPolicy.
|
||||
"""
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import logging
|
||||
import gym
|
||||
from typing import Type
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
import ray.rllib.agents.impala.vtrace_torch as vtrace
|
||||
from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
|
||||
choose_optimizer
|
||||
from ray.rllib.agents.ppo.appo_tf_policy import build_appo_model, \
|
||||
from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model, \
|
||||
postprocess_trajectory
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \
|
||||
KLCoeffMixin
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import \
|
||||
TorchDistributionWrapper, TorchCategorical
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import LearningRateSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
|
||||
sequence_mask
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PPOSurrogateLoss:
|
||||
"""Loss used when V-trace is disabled.
|
||||
def appo_surrogate_loss(policy: Policy, model: ModelV2,
|
||||
dist_class: Type[TorchDistributionWrapper],
|
||||
train_batch: SampleBatch) -> TensorType:
|
||||
"""Constructs the loss for APPO.
|
||||
|
||||
Arguments:
|
||||
prev_actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
action_kl: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
advantages: A float32 tensor of shape [T, B].
|
||||
value_targets: A float32 tensor of shape [T, B].
|
||||
vf_loss_coeff (float): Coefficient of the value function loss.
|
||||
entropy_coeff (float): Coefficient of the entropy regularizer.
|
||||
clip_param (float): Clip parameter.
|
||||
cur_kl_coeff (float): Coefficient for KL loss.
|
||||
use_kl_loss (bool): If true, use KL loss.
|
||||
With IS modifications and V-trace for Advantage Estimation.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to calculate the loss for.
|
||||
model (ModelV2): The Model to calculate the loss for.
|
||||
dist_class (Type[ActionDistribution]: The action distr. class.
|
||||
train_batch (SampleBatch): The training data.
|
||||
|
||||
Returns:
|
||||
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
||||
of loss tensors.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
prev_actions_logp,
|
||||
actions_logp,
|
||||
action_kl,
|
||||
actions_entropy,
|
||||
values,
|
||||
valid_mask,
|
||||
advantages,
|
||||
value_targets,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
clip_param=0.3,
|
||||
cur_kl_coeff=None,
|
||||
use_kl_loss=False):
|
||||
|
||||
if valid_mask is not None:
|
||||
num_valid = torch.sum(valid_mask)
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return torch.sum(t * valid_mask) / num_valid
|
||||
|
||||
else:
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return torch.mean(t)
|
||||
|
||||
logp_ratio = torch.exp(actions_logp - prev_actions_logp)
|
||||
|
||||
surrogate_loss = torch.min(
|
||||
advantages * logp_ratio,
|
||||
advantages * torch.clamp(logp_ratio, 1 - clip_param,
|
||||
1 + clip_param))
|
||||
|
||||
self.mean_kl = reduce_mean_valid(action_kl)
|
||||
self.pi_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The baseline loss
|
||||
delta = values - value_targets
|
||||
self.value_targets = value_targets
|
||||
self.vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss
|
||||
self.entropy = reduce_mean_valid(actions_entropy)
|
||||
|
||||
# The summed weighted loss
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
# Optional additional KL Loss
|
||||
if use_kl_loss:
|
||||
self.total_loss += cur_kl_coeff * self.mean_kl
|
||||
|
||||
|
||||
class VTraceSurrogateLoss:
|
||||
def __init__(self,
|
||||
actions,
|
||||
prev_actions_logp,
|
||||
actions_logp,
|
||||
old_policy_actions_logp,
|
||||
action_kl,
|
||||
actions_entropy,
|
||||
dones,
|
||||
behaviour_logits,
|
||||
old_policy_behaviour_logits,
|
||||
target_logits,
|
||||
discount,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
model,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
clip_param=0.3,
|
||||
cur_kl_coeff=None,
|
||||
use_kl_loss=False):
|
||||
"""APPO Loss, with IS modifications and V-trace for Advantage Estimation
|
||||
|
||||
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
|
||||
batch_size. The reason we need to know `B` is for V-trace to properly
|
||||
handle episode cut boundaries.
|
||||
|
||||
Arguments:
|
||||
actions: An int|float32 tensor of shape [T, B, logit_dim].
|
||||
prev_actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
old_policy_actions_logp: A float32 tensor of shape [T, B].
|
||||
action_kl: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
dones: A bool tensor of shape [T, B].
|
||||
behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
|
||||
old_policy_behaviour_logits: A float32 tensor of shape
|
||||
[T, B, logit_dim].
|
||||
target_logits: A float32 tensor of shape [T, B, logit_dim].
|
||||
discount: A float32 scalar.
|
||||
rewards: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
dist_class: action distribution class for logits.
|
||||
model: backing ModelV2 instance
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
vf_loss_coeff (float): Coefficient of the value function loss.
|
||||
entropy_coeff (float): Coefficient of the entropy regularizer.
|
||||
clip_param (float): Clip parameter.
|
||||
cur_kl_coeff (float): Coefficient for KL loss.
|
||||
use_kl_loss (bool): If true, use KL loss.
|
||||
"""
|
||||
|
||||
if valid_mask is not None:
|
||||
num_valid = torch.sum(valid_mask)
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return torch.sum(t * valid_mask) / num_valid
|
||||
|
||||
else:
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return torch.mean(t)
|
||||
|
||||
# Compute vtrace on the CPU for better perf.
|
||||
self.vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_policy_logits=behaviour_logits,
|
||||
target_policy_logits=old_policy_behaviour_logits,
|
||||
actions=torch.unbind(actions, dim=2),
|
||||
discounts=(1.0 - dones.float()) * discount,
|
||||
rewards=rewards,
|
||||
values=values,
|
||||
bootstrap_value=bootstrap_value,
|
||||
dist_class=dist_class,
|
||||
model=model,
|
||||
clip_rho_threshold=clip_rho_threshold,
|
||||
clip_pg_rho_threshold=clip_pg_rho_threshold)
|
||||
|
||||
self.is_ratio = torch.clamp(
|
||||
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
|
||||
logp_ratio = self.is_ratio * torch.exp(actions_logp -
|
||||
prev_actions_logp)
|
||||
|
||||
advantages = self.vtrace_returns.pg_advantages
|
||||
surrogate_loss = torch.min(
|
||||
advantages * logp_ratio,
|
||||
advantages * torch.clamp(logp_ratio, 1 - clip_param,
|
||||
1 + clip_param))
|
||||
|
||||
self.mean_kl = reduce_mean_valid(action_kl)
|
||||
self.pi_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The baseline loss
|
||||
delta = values - self.vtrace_returns.vs
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
self.vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss
|
||||
self.entropy = reduce_mean_valid(actions_entropy)
|
||||
|
||||
# The summed weighted loss
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
# Optional additional KL Loss
|
||||
if use_kl_loss:
|
||||
self.total_loss += cur_kl_coeff * self.mean_kl
|
||||
|
||||
|
||||
def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
||||
model_out, _ = model.from_batch(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
|
||||
|
@ -239,25 +77,10 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
|
||||
|
||||
target_model_out, _ = policy.target_model.from_batch(train_batch)
|
||||
old_policy_behaviour_logits = target_model_out.detach()
|
||||
|
||||
if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
|
||||
unpacked_behaviour_logits = torch.split(
|
||||
behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.split(
|
||||
old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
unpacked_outputs = torch.split(
|
||||
model_out, list(output_hidden_shape), dim=1)
|
||||
else:
|
||||
unpacked_behaviour_logits = torch.chunk(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.chunk(
|
||||
old_policy_behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
|
||||
|
||||
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
|
||||
prev_action_dist = dist_class(behaviour_logits, policy.model)
|
||||
values = policy.model.value_function()
|
||||
values_time_major = _make_time_major(values)
|
||||
|
||||
policy.model_vars = policy.model.variables()
|
||||
policy.target_model_vars = policy.target_model.variables()
|
||||
|
@ -266,79 +89,147 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
max_seq_len = torch.max(train_batch["seq_lens"]) - 1
|
||||
mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
|
||||
mask = torch.reshape(mask, [-1])
|
||||
num_valid = torch.sum(mask)
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return torch.sum(t * mask) / num_valid
|
||||
|
||||
else:
|
||||
mask = torch.ones_like(rewards)
|
||||
reduce_mean_valid = torch.mean
|
||||
|
||||
if policy.config["vtrace"]:
|
||||
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
|
||||
|
||||
old_policy_behaviour_logits = target_model_out.detach()
|
||||
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
|
||||
|
||||
if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
|
||||
unpacked_behaviour_logits = torch.split(
|
||||
behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.split(
|
||||
old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
else:
|
||||
unpacked_behaviour_logits = torch.chunk(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.chunk(
|
||||
old_policy_behaviour_logits, output_hidden_shape, dim=1)
|
||||
|
||||
# Prepare actions for loss
|
||||
loss_actions = actions if is_multidiscrete else torch.unsqueeze(
|
||||
actions, dim=1)
|
||||
|
||||
# Prepare KL for Loss
|
||||
mean_kl = _make_time_major(
|
||||
action_kl = _make_time_major(
|
||||
old_policy_action_dist.kl(action_dist), drop_last=True)
|
||||
|
||||
policy.loss = VTraceSurrogateLoss(
|
||||
actions=_make_time_major(loss_actions, drop_last=True),
|
||||
prev_actions_logp=_make_time_major(
|
||||
prev_action_dist.logp(actions), drop_last=True),
|
||||
actions_logp=_make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
old_policy_actions_logp=_make_time_major(
|
||||
old_policy_action_dist.logp(actions), drop_last=True),
|
||||
action_kl=mean_kl,
|
||||
actions_entropy=_make_time_major(
|
||||
action_dist.entropy(), drop_last=True),
|
||||
dones=_make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=_make_time_major(
|
||||
# Compute vtrace on the CPU for better perf.
|
||||
vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_policy_logits=_make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
old_policy_behaviour_logits=_make_time_major(
|
||||
target_policy_logits=_make_time_major(
|
||||
unpacked_old_policy_behaviour_logits, drop_last=True),
|
||||
target_logits=_make_time_major(unpacked_outputs, drop_last=True),
|
||||
discount=policy.config["gamma"],
|
||||
actions=torch.unbind(
|
||||
_make_time_major(loss_actions, drop_last=True), dim=2),
|
||||
discounts=(1.0 - _make_time_major(dones, drop_last=True).float()) *
|
||||
policy.config["gamma"],
|
||||
rewards=_make_time_major(rewards, drop_last=True),
|
||||
values=_make_time_major(values, drop_last=True),
|
||||
bootstrap_value=_make_time_major(values)[-1],
|
||||
values=values_time_major[:-1], # drop-last=True
|
||||
bootstrap_value=values_time_major[-1],
|
||||
dist_class=TorchCategorical if is_multidiscrete else dist_class,
|
||||
model=policy.model,
|
||||
valid_mask=_make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
model=model,
|
||||
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=policy.config[
|
||||
"vtrace_clip_pg_rho_threshold"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
cur_kl_coeff=policy.kl_coeff,
|
||||
use_kl_loss=policy.config["use_kl_loss"])
|
||||
"vtrace_clip_pg_rho_threshold"])
|
||||
|
||||
actions_logp = _make_time_major(
|
||||
action_dist.logp(actions), drop_last=True)
|
||||
prev_actions_logp = _make_time_major(
|
||||
prev_action_dist.logp(actions), drop_last=True)
|
||||
old_policy_actions_logp = _make_time_major(
|
||||
old_policy_action_dist.logp(actions), drop_last=True)
|
||||
is_ratio = torch.clamp(
|
||||
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
|
||||
logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
|
||||
policy._is_ratio = is_ratio
|
||||
|
||||
advantages = vtrace_returns.pg_advantages
|
||||
surrogate_loss = torch.min(
|
||||
advantages * logp_ratio,
|
||||
advantages *
|
||||
torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
|
||||
1 + policy.config["clip_param"]))
|
||||
|
||||
mean_kl = reduce_mean_valid(action_kl)
|
||||
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The value function loss.
|
||||
delta = values_time_major[:-1] - vtrace_returns.vs
|
||||
value_targets = vtrace_returns.vs
|
||||
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss.
|
||||
mean_entropy = reduce_mean_valid(
|
||||
_make_time_major(action_dist.entropy(), drop_last=True))
|
||||
|
||||
else:
|
||||
logger.debug("Using PPO surrogate loss (vtrace=False)")
|
||||
|
||||
# Prepare KL for Loss
|
||||
mean_kl = _make_time_major(prev_action_dist.kl(action_dist))
|
||||
action_kl = _make_time_major(prev_action_dist.kl(action_dist))
|
||||
|
||||
policy.loss = PPOSurrogateLoss(
|
||||
prev_actions_logp=_make_time_major(prev_action_dist.logp(actions)),
|
||||
actions_logp=_make_time_major(action_dist.logp(actions)),
|
||||
action_kl=mean_kl,
|
||||
actions_entropy=_make_time_major(action_dist.entropy()),
|
||||
values=_make_time_major(values),
|
||||
valid_mask=_make_time_major(mask),
|
||||
advantages=_make_time_major(
|
||||
train_batch[Postprocessing.ADVANTAGES]),
|
||||
value_targets=_make_time_major(
|
||||
train_batch[Postprocessing.VALUE_TARGETS]),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
clip_param=policy.config["clip_param"],
|
||||
cur_kl_coeff=policy.kl_coeff,
|
||||
use_kl_loss=policy.config["use_kl_loss"])
|
||||
actions_logp = _make_time_major(action_dist.logp(actions))
|
||||
prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
|
||||
logp_ratio = torch.exp(actions_logp - prev_actions_logp)
|
||||
|
||||
return policy.loss.total_loss
|
||||
advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
|
||||
surrogate_loss = torch.min(
|
||||
advantages * logp_ratio,
|
||||
advantages *
|
||||
torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
|
||||
1 + policy.config["clip_param"]))
|
||||
|
||||
mean_kl = reduce_mean_valid(action_kl)
|
||||
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
|
||||
|
||||
# The value function loss.
|
||||
value_targets = _make_time_major(
|
||||
train_batch[Postprocessing.VALUE_TARGETS])
|
||||
delta = values_time_major - value_targets
|
||||
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss.
|
||||
mean_entropy = reduce_mean_valid(
|
||||
_make_time_major(action_dist.entropy()))
|
||||
|
||||
# The summed weighted loss
|
||||
total_loss = mean_policy_loss + \
|
||||
mean_vf_loss * policy.config["vf_loss_coeff"] - \
|
||||
mean_entropy * policy.config["entropy_coeff"]
|
||||
|
||||
# Optional additional KL Loss
|
||||
if policy.config["use_kl_loss"]:
|
||||
total_loss += policy.kl_coeff * mean_kl
|
||||
|
||||
policy._total_loss = total_loss
|
||||
policy._mean_policy_loss = mean_policy_loss
|
||||
policy._mean_kl = mean_kl
|
||||
policy._mean_vf_loss = mean_vf_loss
|
||||
policy._mean_entropy = mean_entropy
|
||||
policy._value_targets = value_targets
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def stats(policy, train_batch):
|
||||
def stats(policy: Policy, train_batch: SampleBatch):
|
||||
"""Stats function for APPO. Returns a dict with important loss stats.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to generate stats for.
|
||||
train_batch (SampleBatch): The SampleBatch (already) used for training.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: The stats dict.
|
||||
"""
|
||||
values_batched = make_time_major(
|
||||
policy,
|
||||
train_batch.get("seq_lens"),
|
||||
|
@ -347,29 +238,36 @@ def stats(policy, train_batch):
|
|||
|
||||
stats_dict = {
|
||||
"cur_lr": policy.cur_lr,
|
||||
"policy_loss": policy.loss.pi_loss,
|
||||
"entropy": policy.loss.entropy,
|
||||
"policy_loss": policy._mean_policy_loss,
|
||||
"entropy": policy._mean_entropy,
|
||||
"var_gnorm": global_norm(policy.model.trainable_variables()),
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_loss": policy._mean_vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
torch.reshape(policy.loss.value_targets, [-1]),
|
||||
torch.reshape(policy._value_targets, [-1]),
|
||||
torch.reshape(values_batched, [-1])),
|
||||
}
|
||||
|
||||
if policy.config["vtrace"]:
|
||||
is_stat_mean = torch.mean(policy.loss.is_ratio, [0, 1])
|
||||
is_stat_var = torch.var(policy.loss.is_ratio, [0, 1])
|
||||
is_stat_mean = torch.mean(policy._is_ratio, [0, 1])
|
||||
is_stat_var = torch.var(policy._is_ratio, [0, 1])
|
||||
stats_dict.update({"mean_IS": is_stat_mean})
|
||||
stats_dict.update({"var_IS": is_stat_var})
|
||||
|
||||
if policy.config["use_kl_loss"]:
|
||||
stats_dict.update({"kl": policy.loss.mean_kl})
|
||||
stats_dict.update({"kl": policy._mean_kl})
|
||||
stats_dict.update({"KL_Coeff": policy.kl_coeff})
|
||||
|
||||
return stats_dict
|
||||
|
||||
|
||||
class TargetNetworkMixin:
|
||||
"""Target NN is updated by master learner via the `update_target` method.
|
||||
|
||||
Updates happen every `trainer.update_target_frequency` steps. All worker
|
||||
batches are importance sampled wrt the target network to ensure a more
|
||||
stable pi_old in PPO.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
def do_update():
|
||||
# Update_target_fn will be called periodically to copy Q network to
|
||||
|
@ -389,19 +287,41 @@ def add_values(policy, input_dict, state_batches, model, action_dist):
|
|||
return out
|
||||
|
||||
|
||||
def setup_early_mixins(policy, obs_space, action_space, config):
|
||||
def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict):
|
||||
"""Call all mixin classes' constructors before APPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
|
||||
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict):
|
||||
"""Call all mixin classes' constructors after APPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
||||
# Build a child class of `TorchPolicy`, given the custom functions defined
|
||||
# above.
|
||||
AsyncPPOTorchPolicy = build_torch_policy(
|
||||
name="AsyncPPOTorchPolicy",
|
||||
loss_fn=build_appo_surrogate_loss,
|
||||
loss_fn=appo_surrogate_loss,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
extra_action_out_fn=add_values,
|
||||
|
@ -409,7 +329,7 @@ AsyncPPOTorchPolicy = build_torch_policy(
|
|||
optimizer_fn=choose_optimizer,
|
||||
before_init=setup_early_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
make_model=build_appo_model,
|
||||
make_model=make_appo_model,
|
||||
mixins=[
|
||||
LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
|
||||
ValueNetworkMixin
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Decentralized Distributed PPO implementation.
|
||||
"""
|
||||
Decentralized Distributed PPO (DD-PPO)
|
||||
======================================
|
||||
|
||||
Unlike APPO or PPO, learning is no longer done centralized in the trainer
|
||||
process. Instead, gradients are computed remotely on each rollout worker and
|
||||
|
@ -19,6 +21,7 @@ import time
|
|||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo import ppo
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
|
@ -26,11 +29,16 @@ from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
|||
_get_shared_metrics, _get_global_vars
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.util.iter import LocalIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the `PPOTrainer` config in
|
||||
# rllib/agents/ppo/ppo.py.
|
||||
DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
|
||||
ppo.DEFAULT_CONFIG,
|
||||
{
|
||||
|
@ -67,36 +75,64 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
|
|||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def validate_config(config):
    """Validates the Trainer's config dict.

    Args:
        config (TrainerConfigDict): The Trainer's config to check.

    Throws:
        ValueError: In case something is wrong with the config.
    """

    # Auto-train_batch_size: Calculate from rollout len and envs-per-worker.
    if config["train_batch_size"] == -1:
        # Auto set.
        config["train_batch_size"] = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"])
    # Users should not define `train_batch_size` directly (always -1).
    else:
        raise ValueError(
            "Set rollout_fragment_length instead of train_batch_size "
            "for DDPPO.")

    # Only supported for PyTorch so far.
    if config["framework"] != "torch":
        raise ValueError(
            "Distributed data parallel is only supported for PyTorch")
    # `num_gpus` must be 0/None, since all optimization happens on Workers.
    if config["num_gpus"]:
        raise ValueError(
            "When using distributed data parallel, you should set "
            "num_gpus=0 since all optimization "
            "is happening on workers. Enable GPUs for workers by setting "
            "num_gpus_per_worker=1.")
    # `batch_mode` must be "truncate_episodes".
    if config["batch_mode"] != "truncate_episodes":
        raise ValueError(
            "Distributed data parallel requires truncate_episodes "
            "batch mode.")
    # Call (base) PPO's config validation function.
    ppo.validate_config(config)

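As the auto-set branch above implies, DD-PPO derives the per-worker train batch from the sampling settings rather than from a user-supplied `train_batch_size`. A minimal sketch of that arithmetic, with purely illustrative numbers (not defaults taken from this diff):

    # Hypothetical sampling settings, for illustration only.
    rollout_fragment_length = 100   # env steps each worker collects per rollout
    num_envs_per_worker = 5         # vectorized envs on every rollout worker

    # validate_config() then auto-computes the per-worker train batch:
    train_batch_size = rollout_fragment_length * num_envs_per_worker
    assert train_batch_size == 500  # each worker optimizes on 500 steps per round
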
def execution_plan(workers, config):
|
||||
def execution_plan(workers: WorkerSet,
|
||||
config: TrainerConfigDict) -> LocalIterator[dict]:
|
||||
"""Execution plan of the DD-PPO algorithm. Defines the distributed dataflow.
|
||||
|
||||
Args:
|
||||
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
|
||||
of the Trainer.
|
||||
config (TrainerConfigDict): The trainer's configuration dict.
|
||||
|
||||
Returns:
|
||||
LocalIterator[dict]: The Policy class to use with PGTrainer.
|
||||
If None, use `default_policy` provided in build_trainer().
|
||||
"""
|
||||
rollouts = ParallelRollouts(workers, mode="raw")
|
||||
|
||||
# Setup the distributed processes.
|
||||
|
@ -194,8 +230,12 @@ def execution_plan(workers, config):
|
|||
return StandardMetricsReporting(train_op, workers, config)
|
||||
|
||||
|
||||
# Build a child class of `Trainer`, based on PPOTrainer's setup.
|
||||
# Note: The generated class is NOT a sub-class of PPOTrainer, but directly of
|
||||
# the `Trainer` class.
|
||||
DDPPOTrainer = ppo.PPOTrainer.with_updates(
|
||||
name="DDPPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
validate_config=validate_config,
|
||||
execution_plan=execution_plan,
|
||||
validate_config=validate_config)
|
||||
)
|
||||
|
|
|
@ -1,17 +1,36 @@
|
|||
"""
|
||||
Proximal Policy Optimization (PPO)
|
||||
==================================
|
||||
|
||||
This file defines the distributed Trainer class for proximal policy
|
||||
optimization.
|
||||
See `ppo_[tf|torch]_policy.py` for the definition of the policy loss.
|
||||
|
||||
Detailed documentation: https://docs.ray.io/en/latest/rllib-algorithms.html#ppo
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
|
||||
StandardizeFields, SelectExperiences
|
||||
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.util.iter import LocalIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the (base) `Trainer` config in
|
||||
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# Should use a critic as a baseline (otherwise don't use value baseline;
|
||||
# required for using GAE).
|
||||
|
@ -71,10 +90,103 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
|
||||
"_fake_gpus": False,
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def validate_config(config: TrainerConfigDict) -> None:
    """Validates the Trainer's config dict.

    Args:
        config (TrainerConfigDict): The Trainer's config to check.

    Throws:
        ValueError: In case something is wrong with the config.
    """
    if isinstance(config["entropy_coeff"], int):
        config["entropy_coeff"] = float(config["entropy_coeff"])

    if config["entropy_coeff"] < 0.0:
        raise DeprecationWarning("entropy_coeff must be >= 0.0")

    # SGD minibatch size must be smaller than train_batch_size (b/c
    # we subsample a batch of `sgd_minibatch_size` from the train-batch for
    # each `sgd_num_iter`).
    if config["sgd_minibatch_size"] > config["train_batch_size"]:
        raise ValueError("`sgd_minibatch_size` ({}) must be <= "
                         "`train_batch_size` ({}).".format(
                             config["sgd_minibatch_size"],
                             config["train_batch_size"]))

    # Episodes may only be truncated (and passed into PPO's
    # `postprocessing_fn`), iff generalized advantage estimation is used
    # (value function estimate at end of truncated episode to estimate
    # remaining value).
    if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]:
        raise ValueError(
            "Episode truncation is not supported without a value "
            "function. Consider setting batch_mode=complete_episodes.")

    # Multi-gpu not supported for PyTorch and tf-eager.
    if config["framework"] in ["tf2", "tfe", "torch"]:
        config["simple_optimizer"] = True
    # Performance warning, if "simple" optimizer used with (static-graph) tf.
    elif config["simple_optimizer"]:
        logger.warning(
            "Using the simple minibatch optimizer. This will significantly "
            "reduce performance, consider simple_optimizer=False.")
    # Multi-agent mode and multi-GPU optimizer.
    elif config["multiagent"]["policies"] and not config["simple_optimizer"]:
        logger.info(
            "In multi-agent mode, policies will be optimized sequentially "
            "by the multi-GPU optimizer. Consider setting "
            "simple_optimizer=True if this doesn't work for you.")

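For orientation, here is a small, hypothetical PPO config fragment that passes every check above (the concrete numbers are illustrative, not defaults from this change):

    config = {
        "framework": "tf",
        "entropy_coeff": 0.0,               # must be >= 0.0
        "train_batch_size": 4000,
        "sgd_minibatch_size": 128,          # must be <= train_batch_size
        "batch_mode": "truncate_episodes",
        "use_gae": True,                    # required when truncating episodes
        "simple_optimizer": False,          # static-graph tf may use the multi-GPU optimizer
        "multiagent": {"policies": {}},
    }
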
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Policy class picker function. Class is chosen based on DL-framework.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with PPOTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":
        from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
        return PPOTorchPolicy


class UpdateKL:
    """Callback to update the KL based on optimization info.

    This is used inside the execution_plan function. The Policy must define
    an `update_kl` method for this to work. This is achieved for PPO via a
    Policy mixin class (which adds the `update_kl` method),
    defined in ppo_[tf|torch]_policy.py.
    """

    def __init__(self, workers):
        self.workers = workers

    def __call__(self, fetches):
        def update(pi, pi_id):
            assert "kl" not in fetches, (
                "kl should be nested under policy id key", fetches)
            if pi_id in fetches:
                assert "kl" in fetches[pi_id], (fetches, pi_id)
                # Make the actual `Policy.update_kl()` call.
                pi.update_kl(fetches[pi_id]["kl"])
            else:
                logger.warning("No data for {}, not updating kl".format(pi_id))

        # Update KL on all trainable policies within the local (trainer)
        # Worker.
        self.workers.local_worker().foreach_trainable_policy(update)

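The callback above expects the learner fetches to be keyed by policy ID. A tiny, hypothetical example of the structure it consumes (values are made up):

    # Hypothetical learner output after one training round.
    fetches = {
        "default_policy": {"kl": 0.023},   # mean KL measured on the train batch
    }

    # For each trainable policy, UpdateKL.__call__ then effectively runs:
    #     policy.update_kl(fetches["default_policy"]["kl"])
    # which adapts the KL coefficient used by the next loss computation.
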
def warn_about_bad_reward_scales(config, result):
|
||||
if result["policy_reward_mean"]:
|
||||
return result # Punt on handling multiagent case.
|
||||
|
@ -111,71 +223,31 @@ def warn_about_bad_reward_scales(config, result):
|
|||
return result
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
if isinstance(config["entropy_coeff"], int):
|
||||
config["entropy_coeff"] = float(config["entropy_coeff"])
|
||||
if config["sgd_minibatch_size"] > config["train_batch_size"]:
|
||||
raise ValueError("`sgd_minibatch_size` ({}) must be <= "
|
||||
"`train_batch_size` ({}).".format(
|
||||
config["sgd_minibatch_size"],
|
||||
config["train_batch_size"]))
|
||||
if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]:
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value "
|
||||
"function. Consider setting batch_mode=complete_episodes.")
|
||||
if config["multiagent"]["policies"] and not config["simple_optimizer"]:
|
||||
logger.info(
|
||||
"In multi-agent mode, policies will be optimized sequentially "
|
||||
"by the multi-GPU optimizer. Consider setting "
|
||||
"simple_optimizer=True if this doesn't work for you.")
|
||||
if config["simple_optimizer"]:
|
||||
logger.warning(
|
||||
"Using the simple minibatch optimizer. This will significantly "
|
||||
"reduce performance, consider simple_optimizer=False.")
|
||||
# Multi-gpu not supported for PyTorch and tf-eager.
|
||||
elif config["framework"] in ["tf2", "tfe", "torch"]:
|
||||
config["simple_optimizer"] = True
|
||||
def execution_plan(workers: WorkerSet,
|
||||
config: TrainerConfigDict) -> LocalIterator[dict]:
|
||||
"""Execution plan of the PPO algorithm. Defines the distributed dataflow.
|
||||
|
||||
Args:
|
||||
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
|
||||
of the Trainer.
|
||||
config (TrainerConfigDict): The trainer's configuration dict.
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
|
||||
return PPOTorchPolicy
|
||||
else:
|
||||
return PPOTFPolicy
|
||||
|
||||
|
||||
class UpdateKL:
|
||||
"""Callback to update the KL based on optimization info."""
|
||||
|
||||
def __init__(self, workers):
|
||||
self.workers = workers
|
||||
|
||||
def __call__(self, fetches):
|
||||
def update(pi, pi_id):
|
||||
assert "kl" not in fetches, (
|
||||
"kl should be nested under policy id key", fetches)
|
||||
if pi_id in fetches:
|
||||
assert "kl" in fetches[pi_id], (fetches, pi_id)
|
||||
pi.update_kl(fetches[pi_id]["kl"])
|
||||
else:
|
||||
logger.warning("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
self.workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
|
||||
def execution_plan(workers, config):
|
||||
Returns:
|
||||
LocalIterator[dict]: The Policy class to use with PPOTrainer.
|
||||
If None, use `default_policy` provided in build_trainer().
|
||||
"""
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
# Collect large batches of relevant experiences & standardize.
|
||||
# Collect batches for the trainable policies.
|
||||
rollouts = rollouts.for_each(
|
||||
SelectExperiences(workers.trainable_policies()))
|
||||
# Concatenate the SampleBatches into one.
|
||||
rollouts = rollouts.combine(
|
||||
ConcatBatches(min_batch_size=config["train_batch_size"]))
|
||||
# Standardize advantages.
|
||||
rollouts = rollouts.for_each(StandardizeFields(["advantages"]))
|
||||
|
||||
# Perform one training step on the combined + standardized batch.
|
||||
if config["simple_optimizer"]:
|
||||
train_op = rollouts.for_each(
|
||||
TrainOneStep(
|
||||
|
@ -199,14 +271,17 @@ def execution_plan(workers, config):
|
|||
# Update KL after each round of training.
|
||||
train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))
|
||||
|
||||
# Warn about bad reward scales and return training metrics.
|
||||
return StandardMetricsReporting(train_op, workers, config) \
|
||||
.for_each(lambda result: warn_about_bad_reward_scales(config, result))
|
||||
|
||||
|
||||
# Build a child class of `Trainer`, which uses the framework specific Policy
|
||||
# determined in `get_policy_class()` above.
|
||||
PPOTrainer = build_trainer(
|
||||
name="PPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
validate_config=validate_config,
|
||||
default_policy=PPOTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
execution_plan=execution_plan,
|
||||
validate_config=validate_config)
|
||||
execution_plan=execution_plan)
|
||||
|
|
|
@ -1,176 +1,192 @@
|
|||
"""
|
||||
TensorFlow policy class used for PPO.
|
||||
"""
|
||||
|
||||
import gym
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils.framework import try_import_tf, get_variable
|
||||
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
||||
from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
|
||||
TensorType, TrainerConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PPOLoss:
|
||||
def __init__(self,
|
||||
dist_class,
|
||||
model,
|
||||
value_targets,
|
||||
advantages,
|
||||
actions,
|
||||
prev_logits,
|
||||
prev_actions_logp,
|
||||
vf_preds,
|
||||
curr_action_dist,
|
||||
value_fn,
|
||||
cur_kl_coeff,
|
||||
valid_mask,
|
||||
entropy_coeff=0,
|
||||
clip_param=0.1,
|
||||
vf_clip_param=0.1,
|
||||
vf_loss_coeff=1.0,
|
||||
use_gae=True):
|
||||
"""Constructs the loss for Proximal Policy Objective.
|
||||
def ppo_surrogate_loss(
|
||||
policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
|
||||
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
|
||||
"""Constructs the loss for Proximal Policy Objective.
|
||||
|
||||
Arguments:
|
||||
dist_class: action distribution class for logits.
|
||||
value_targets (Placeholder): Placeholder for target values; used
|
||||
for GAE.
|
||||
actions (Placeholder): Placeholder for actions taken
|
||||
from previous model evaluation.
|
||||
advantages (Placeholder): Placeholder for calculated advantages
|
||||
from previous model evaluation.
|
||||
prev_logits (Placeholder): Placeholder for logits output from
|
||||
previous model evaluation.
|
||||
prev_actions_logp (Placeholder): Placeholder for action prob output
|
||||
from the previous (before update) Model evaluation.
|
||||
vf_preds (Placeholder): Placeholder for value function output
|
||||
from the previous (before update) Model evaluation.
|
||||
curr_action_dist (ActionDistribution): ActionDistribution
|
||||
of the current model.
|
||||
value_fn (Tensor): Current value function output Tensor.
|
||||
cur_kl_coeff (Variable): Variable holding the current PPO KL
|
||||
coefficient.
|
||||
valid_mask (Optional[tf.Tensor]): An optional bool mask of valid
|
||||
input elements (for max-len padded sequences (RNNs)).
|
||||
entropy_coeff (float): Coefficient of the entropy regularizer.
|
||||
clip_param (float): Clip parameter
|
||||
vf_clip_param (float): Clip parameter for the value function
|
||||
vf_loss_coeff (float): Coefficient of the value function loss
|
||||
use_gae (bool): If true, use the Generalized Advantage Estimator.
|
||||
"""
|
||||
if valid_mask is not None:
|
||||
Args:
|
||||
policy (Policy): The Policy to calculate the loss for.
|
||||
model (ModelV2): The Model to calculate the loss for.
|
||||
dist_class (Type[ActionDistribution]: The action distr. class.
|
||||
train_batch (SampleBatch): The training data.
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
|
||||
|
||||
else:
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(t)
|
||||
|
||||
prev_dist = dist_class(prev_logits, model)
|
||||
# Make loss functions.
|
||||
logp_ratio = tf.exp(curr_action_dist.logp(actions) - prev_actions_logp)
|
||||
action_kl = prev_dist.kl(curr_action_dist)
|
||||
self.mean_kl = reduce_mean_valid(action_kl)
|
||||
|
||||
curr_entropy = curr_action_dist.entropy()
|
||||
self.mean_entropy = reduce_mean_valid(curr_entropy)
|
||||
|
||||
surrogate_loss = tf.minimum(
|
||||
advantages * logp_ratio,
|
||||
advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
|
||||
1 + clip_param))
|
||||
self.mean_policy_loss = reduce_mean_valid(-surrogate_loss)
|
||||
|
||||
if use_gae:
|
||||
vf_loss1 = tf.math.square(value_fn - value_targets)
|
||||
vf_clipped = vf_preds + tf.clip_by_value(
|
||||
value_fn - vf_preds, -vf_clip_param, vf_clip_param)
|
||||
vf_loss2 = tf.math.square(vf_clipped - value_targets)
|
||||
vf_loss = tf.maximum(vf_loss1, vf_loss2)
|
||||
self.mean_vf_loss = reduce_mean_valid(vf_loss)
|
||||
loss = reduce_mean_valid(
|
||||
-surrogate_loss + cur_kl_coeff * action_kl +
|
||||
vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy)
|
||||
else:
|
||||
self.mean_vf_loss = tf.constant(0.0)
|
||||
loss = reduce_mean_valid(-surrogate_loss +
|
||||
cur_kl_coeff * action_kl -
|
||||
entropy_coeff * curr_entropy)
|
||||
self.loss = loss
|
||||
|
||||
|
||||
def ppo_surrogate_loss(policy, model, dist_class, train_batch):
|
||||
Returns:
|
||||
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
||||
of loss tensors.
|
||||
"""
|
||||
logits, state = model.from_batch(train_batch)
|
||||
action_dist = dist_class(logits, model)
|
||||
curr_action_dist = dist_class(logits, model)
|
||||
|
||||
mask = None
|
||||
# RNN case: Mask away 0-padded chunks at end of time axis.
|
||||
if state:
|
||||
max_seq_len = tf.reduce_max(train_batch["seq_lens"])
|
||||
mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
|
||||
policy.loss_obj = PPOLoss(
|
||||
dist_class,
|
||||
model,
|
||||
train_batch[Postprocessing.VALUE_TARGETS],
|
||||
train_batch[Postprocessing.ADVANTAGES],
|
||||
train_batch[SampleBatch.ACTIONS],
|
||||
train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
train_batch[SampleBatch.ACTION_LOGP],
|
||||
train_batch[SampleBatch.VF_PREDS],
|
||||
action_dist,
|
||||
model.value_function(),
|
||||
policy.kl_coeff,
|
||||
mask,
|
||||
entropy_coeff=policy.entropy_coeff,
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
)
|
||||
def reduce_mean_valid(t):
|
||||
return tf.reduce_mean(tf.boolean_mask(t, mask))
|
||||
|
||||
return policy.loss_obj.loss
|
||||
# non-RNN case: No masking.
|
||||
else:
|
||||
mask = None
|
||||
reduce_mean_valid = tf.reduce_mean
|
||||
|
||||
    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    logp_ratio = tf.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = tf.minimum(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value(
            logp_ratio, 1 - policy.config["clip_param"],
            1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = tf.math.square(value_fn_out -
                                  train_batch[Postprocessing.VALUE_TARGETS])
        vf_clipped = prev_value_fn_out + tf.clip_by_value(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = tf.math.square(vf_clipped -
                                  train_batch[Postprocessing.VALUE_TARGETS])
        vf_loss = tf.maximum(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(
            -surrogate_loss + policy.kl_coeff * action_kl +
            policy.config["vf_loss_coeff"] * vf_loss -
            policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = tf.constant(0.0)
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss

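The surrogate term above is PPO's clipped objective. A scalar walk-through of the same arithmetic, with made-up numbers, to make the clipping concrete:

    # Illustrative scalars; in the loss above these are per-timestep tensors.
    clip_param = 0.3
    advantage = 2.0
    logp_ratio = 1.8   # exp(logp_new - logp_old)

    clipped_ratio = min(max(logp_ratio, 1 - clip_param), 1 + clip_param)  # -> 1.3
    surrogate = min(advantage * logp_ratio, advantage * clipped_ratio)    # -> 2.6

    # With a positive advantage, clipping caps the incentive to push the ratio
    # past 1 + clip_param; the loss then minimizes -surrogate, exactly as in
    # the tensor version above.
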
def kl_and_loss_stats(policy, train_batch):
|
||||
def kl_and_loss_stats(policy: Policy,
|
||||
train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
"""Stats function for PPO. Returns a dict with important KL and loss stats.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to generate stats for.
|
||||
train_batch (SampleBatch): The SampleBatch (already) used for training.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: The stats dict.
|
||||
"""
|
||||
return {
|
||||
"cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64),
|
||||
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
|
||||
"total_loss": policy.loss_obj.loss,
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss,
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"total_loss": policy._total_loss,
|
||||
"policy_loss": policy._mean_policy_loss,
|
||||
"vf_loss": policy._mean_vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
train_batch[Postprocessing.VALUE_TARGETS],
|
||||
policy.model.value_function()),
|
||||
"kl": policy.loss_obj.mean_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
"kl": policy._mean_kl,
|
||||
"entropy": policy._mean_entropy,
|
||||
"entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
|
||||
}
|
||||
|
||||
|
||||
def vf_preds_fetches(policy):
|
||||
"""Adds value function outputs to experience train_batches."""
|
||||
def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
|
||||
"""Defines extra fetches per action computation.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to perform the extra action fetch on.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: Dict with extra tf fetches to perform per
|
||||
action computation.
|
||||
"""
|
||||
# Return value function outputs. VF estimates will hence be added to the
|
||||
# SampleBatches produced by the sampler(s) to generate the train batches
|
||||
# going into the loss function.
|
||||
return {
|
||||
SampleBatch.VF_PREDS: policy.model.value_function(),
|
||||
}
|
||||
|
||||
|
||||
def postprocess_ppo_gae(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
"""Adds the policy logits, VF preds, and advantages to the trajectory."""
|
||||
def postprocess_ppo_gae(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
|
||||
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
|
||||
"""Postprocesses a trajectory and returns the processed trajectory.
|
||||
|
||||
completed = sample_batch[SampleBatch.DONES][-1]
|
||||
if completed:
|
||||
The trajectory contains only data from one episode and from one agent.
|
||||
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
|
||||
contain a truncated (at-the-end) episode, in case the
|
||||
`config.rollout_fragment_length` was reached by the sampler.
|
||||
- If `config.batch_mode=complete_episodes`, sample_batch will contain
|
||||
exactly one episode (no matter how long).
|
||||
New columns can be added to sample_batch and existing ones may be altered.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy used to generate the trajectory
|
||||
(`sample_batch`)
|
||||
sample_batch (SampleBatch): The SampleBatch to postprocess.
|
||||
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
||||
dict of AgentIDs mapping to other agents' trajectory data (from the
|
||||
same episode). NOTE: The other agents use the same policy.
|
||||
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
|
||||
object in which the agents operated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
|
||||
    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        next_state = []
        for i in range(policy.num_state_tensors()):

@ -179,6 +195,9 @@ def postprocess_ppo_gae(policy,

            sample_batch[SampleBatch.ACTIONS][-1],
            sample_batch[SampleBatch.REWARDS][-1],
            *next_state)

    # Adds the policy logits, VF preds, and advantages to the batch,
    # using GAE ("generalized advantage estimation") or not.
    batch = compute_advantages(
        sample_batch,
        last_r,

@ -188,38 +207,81 @@ def postprocess_ppo_gae(policy,

    return batch

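`compute_advantages()` then turns rewards, value predictions, and the bootstrap value `last_r` into advantages and value targets. As a generic sketch of generalized advantage estimation (not the exact RLlib implementation; the `gamma`/`lam` values are illustrative):

    def gae_advantages(rewards, values, last_r, gamma=0.99, lam=0.95):
        # Append the bootstrap value V(s_T): 0.0 for finished episodes,
        # otherwise the value estimate of the last observation (see above).
        values = list(values) + [last_r]
        advantages = [0.0] * len(rewards)
        gae = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * values[t + 1] - values[t]
            gae = delta + gamma * lam * gae
            advantages[t] = gae
        # Value targets = advantages + value baseline.
        targets = [a + v for a, v in zip(advantages, values[:-1])]
        return advantages, targets
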
def clip_gradients(policy, optimizer, loss):
def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                               loss: TensorType) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Args:
        policy (Policy): The Policy object that generated the loss tensor and
            that holds the given local optimizer.
        optimizer (LocalOptimizer): The tf (local) optimizer object to
            calculate the gradients with.
        loss (TensorType): The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the possibly clipped gradients- and variable
            tuples.
    """
    # Compute the gradients.
    variables = policy.model.trainable_variables()
    grads_and_vars = optimizer.compute_gradients(loss, variables)

    # Clip by global norm, if necessary.
    if policy.config["grad_clip"] is not None:
        grads_and_vars = optimizer.compute_gradients(loss, variables)
        grads = [g for (g, v) in grads_and_vars]
        policy.grads, _ = tf.clip_by_global_norm(grads,
                                                 policy.config["grad_clip"])
        clipped_grads = list(zip(policy.grads, variables))
        return clipped_grads
        clipped_grads_and_vars = list(zip(policy.grads, variables))
        return clipped_grads_and_vars
    else:
        return optimizer.compute_gradients(loss, variables)
        return grads_and_vars

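Global-norm clipping, as used above, rescales all gradients jointly so that their combined L2 norm never exceeds `grad_clip`, preserving their relative directions. A minimal standalone sketch with hypothetical tensors (plain TensorFlow, independent of the policy machinery above):

    import tensorflow as tf

    grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]  # global norm = 13.0
    clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=6.5)
    # Every gradient is scaled by 6.5 / 13.0 = 0.5.
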
class KLCoeffMixin:
    """Assigns the `update_kl()` method to the PPOPolicy.

    This is used in PPO's execution plan (see ppo.py) for updating the KL
    coefficient after each learning step based on `config.kl_target` and
    the measured KL value (from the train_batch).
    """

    def __init__(self, config):
        # KL Coefficient
        # The current KL value (as python float).
        self.kl_coeff_val = config["kl_coeff"]
        self.kl_target = config["kl_target"]
        # The current KL value (as tf Variable for in-graph operations).
        self.kl_coeff = get_variable(
            float(self.kl_coeff_val), tf_name="kl_coeff", trainable=False)
        # Constant target value.
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        # Update the current KL value based on the recently measured value.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff_val *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff_val *= 0.5

        # Update the tf Variable (via session call).
        self.kl_coeff.load(self.kl_coeff_val, session=self.get_session())
        # Return the current KL value.
        return self.kl_coeff_val

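The adaptive rule in `update_kl()` nudges the coefficient so that the measured KL stays near `kl_target`. A scalar walk-through with hypothetical numbers:

    kl_target = 0.01
    kl_coeff = 0.2

    # Measured KL far above target -> penalize divergence more strongly.
    sampled_kl = 0.03        # > 2.0 * kl_target
    kl_coeff *= 1.5          # -> 0.3

    # Measured KL far below target -> relax the penalty.
    sampled_kl = 0.004       # < 0.5 * kl_target
    kl_coeff *= 0.5          # -> 0.15 (from 0.3)

    # Anything in between (here 0.005 .. 0.02) leaves the coefficient unchanged.
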
class ValueNetworkMixin:
|
||||
"""Assigns the `_value()` method to the PPOPolicy.
|
||||
|
||||
This way, Policy can call `_value()` to get the current VF estimate on a
|
||||
single(!) observation (as done in `postprocess_trajectory_fn`).
|
||||
Note: When doing this, an actual forward pass is being performed.
|
||||
This is different from only calling `model.value_function()`, where
|
||||
the result of the most recent forward pass is being used to return an
|
||||
already calculated tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
# When doing GAE, we need the value function estimate on the
|
||||
# observation.
|
||||
if config["use_gae"]:
|
||||
|
||||
@make_tf_callable(self.get_session())
|
||||
|
@ -233,8 +295,10 @@ class ValueNetworkMixin:
|
|||
"is_training": tf.convert_to_tensor([False]),
|
||||
}, [tf.convert_to_tensor([s]) for s in state],
|
||||
tf.convert_to_tensor([1]))
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
# When not doing GAE, we do not require the value function's output.
|
||||
else:
|
||||
|
||||
@make_tf_callable(self.get_session())
|
||||
|
@ -244,12 +308,32 @@ class ValueNetworkMixin:
|
|||
self._value = value
|
||||
|
||||
|
||||
def setup_config(policy, obs_space, action_space, config):
|
||||
# auto set the model option for layer sharing
|
||||
def setup_config(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Executed before Policy is "initialized" (at beginning of constructor).
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
# Auto set the model option for VF layer sharing.
|
||||
config["model"]["vf_share_layers"] = config["vf_share_layers"]
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Call all mixin classes' constructors before PPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
|
||||
|
@ -257,14 +341,16 @@ def setup_mixins(policy, obs_space, action_space, config):
|
|||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
# Build a child class of `DynamicTFPolicy`, given the custom functions defined
|
||||
# above.
|
||||
PPOTFPolicy = build_tf_policy(
|
||||
name="PPOTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
|
||||
loss_fn=ppo_surrogate_loss,
|
||||
stats_fn=kl_and_loss_stats,
|
||||
extra_action_fetches_fn=vf_preds_fetches,
|
||||
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
gradients_fn=clip_gradients,
|
||||
stats_fn=kl_and_loss_stats,
|
||||
gradients_fn=compute_and_clip_gradients,
|
||||
extra_action_fetches_fn=vf_preds_fetches,
|
||||
before_init=setup_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
"""
|
||||
PyTorch policy class used for PPO.
|
||||
"""
|
||||
import gym
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
setup_config
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
|
||||
LearningRateSchedule
|
||||
|
@ -14,107 +22,33 @@ from ray.rllib.policy.view_requirement import ViewRequirement
|
|||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
|
||||
explained_variance, sequence_mask
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PPOLoss:
|
||||
def __init__(self,
|
||||
dist_class,
|
||||
model,
|
||||
value_targets,
|
||||
advantages,
|
||||
actions,
|
||||
prev_logits,
|
||||
prev_actions_logp,
|
||||
vf_preds,
|
||||
curr_action_dist,
|
||||
value_fn,
|
||||
cur_kl_coeff,
|
||||
valid_mask,
|
||||
entropy_coeff=0,
|
||||
clip_param=0.1,
|
||||
vf_clip_param=0.1,
|
||||
vf_loss_coeff=1.0,
|
||||
use_gae=True):
|
||||
"""Constructs the loss for Proximal Policy Objective.
|
||||
def ppo_surrogate_loss(
|
||||
policy: Policy, model: ModelV2,
|
||||
dist_class: Type[TorchDistributionWrapper],
|
||||
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
|
||||
"""Constructs the loss for Proximal Policy Objective.
|
||||
|
||||
Arguments:
|
||||
dist_class: action distribution class for logits.
|
||||
value_targets (Placeholder): Placeholder for target values; used
|
||||
for GAE.
|
||||
actions (Placeholder): Placeholder for actions taken
|
||||
from previous model evaluation.
|
||||
advantages (Placeholder): Placeholder for calculated advantages
|
||||
from previous model evaluation.
|
||||
prev_logits (Placeholder): Placeholder for logits output from
|
||||
previous model evaluation.
|
||||
prev_actions_logp (Placeholder): Placeholder for prob output from
|
||||
previous model evaluation.
|
||||
vf_preds (Placeholder): Placeholder for value function output
|
||||
from previous model evaluation.
|
||||
curr_action_dist (ActionDistribution): ActionDistribution
|
||||
of the current model.
|
||||
value_fn (Tensor): Current value function output Tensor.
|
||||
cur_kl_coeff (Variable): Variable holding the current PPO KL
|
||||
coefficient.
|
||||
valid_mask (Tensor): A bool mask of valid input elements (#2992).
|
||||
entropy_coeff (float): Coefficient of the entropy regularizer.
|
||||
clip_param (float): Clip parameter
|
||||
vf_clip_param (float): Clip parameter for the value function
|
||||
vf_loss_coeff (float): Coefficient of the value function loss
|
||||
use_gae (bool): If true, use the Generalized Advantage Estimator.
|
||||
"""
|
||||
if valid_mask is not None:
|
||||
num_valid = torch.sum(valid_mask)
|
||||
Args:
|
||||
policy (Policy): The Policy to calculate the loss for.
|
||||
model (ModelV2): The Model to calculate the loss for.
|
||||
dist_class (Type[ActionDistribution]: The action distr. class.
|
||||
train_batch (SampleBatch): The training data.
|
||||
|
||||
def reduce_mean_valid(t):
|
||||
return torch.sum(t[valid_mask]) / num_valid
|
||||
|
||||
else:
|
||||
reduce_mean_valid = torch.mean
|
||||
|
||||
prev_dist = dist_class(prev_logits, model)
|
||||
# Make loss functions.
|
||||
logp_ratio = torch.exp(
|
||||
curr_action_dist.logp(actions) - prev_actions_logp)
|
||||
action_kl = prev_dist.kl(curr_action_dist)
|
||||
self.mean_kl = reduce_mean_valid(action_kl)
|
||||
|
||||
curr_entropy = curr_action_dist.entropy()
|
||||
self.mean_entropy = reduce_mean_valid(curr_entropy)
|
||||
|
||||
surrogate_loss = torch.min(
|
||||
advantages * logp_ratio,
|
||||
advantages * torch.clamp(logp_ratio, 1 - clip_param,
|
||||
1 + clip_param))
|
||||
self.mean_policy_loss = reduce_mean_valid(-surrogate_loss)
|
||||
|
||||
if use_gae:
|
||||
vf_loss1 = torch.pow(value_fn - value_targets, 2.0)
|
||||
vf_clipped = vf_preds + torch.clamp(value_fn - vf_preds,
|
||||
-vf_clip_param, vf_clip_param)
|
||||
vf_loss2 = torch.pow(vf_clipped - value_targets, 2.0)
|
||||
vf_loss = torch.max(vf_loss1, vf_loss2)
|
||||
self.mean_vf_loss = reduce_mean_valid(vf_loss)
|
||||
loss = reduce_mean_valid(
|
||||
-surrogate_loss + cur_kl_coeff * action_kl +
|
||||
vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy)
|
||||
else:
|
||||
self.mean_vf_loss = 0.0
|
||||
loss = reduce_mean_valid(-surrogate_loss +
|
||||
cur_kl_coeff * action_kl -
|
||||
entropy_coeff * curr_entropy)
|
||||
self.loss = loss
|
||||
|
||||
|
||||
def ppo_surrogate_loss(policy, model, dist_class, train_batch):
|
||||
Returns:
|
||||
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
||||
of loss tensors.
|
||||
"""
|
||||
logits, state = model.from_batch(train_batch, is_training=True)
|
||||
action_dist = dist_class(logits, model)
|
||||
curr_action_dist = dist_class(logits, model)
|
||||
|
||||
mask = None
|
||||
# RNN case: Mask away 0-padded chunks at end of time axis.
|
||||
if state:
|
||||
max_seq_len = torch.max(train_batch["seq_lens"])
|
||||
mask = sequence_mask(
|
||||
|
@ -122,69 +56,160 @@ def ppo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
max_seq_len,
|
||||
time_major=model.is_time_major())
|
||||
mask = torch.reshape(mask, [-1])
|
||||
num_valid = torch.sum(mask)
|
||||
|
||||
policy.loss_obj = PPOLoss(
|
||||
dist_class,
|
||||
model,
|
||||
train_batch[Postprocessing.VALUE_TARGETS],
|
||||
train_batch[Postprocessing.ADVANTAGES],
|
||||
train_batch[SampleBatch.ACTIONS],
|
||||
train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
train_batch[SampleBatch.ACTION_LOGP],
|
||||
train_batch[SampleBatch.VF_PREDS],
|
||||
action_dist,
|
||||
model.value_function(),
|
||||
policy.kl_coeff,
|
||||
mask,
|
||||
entropy_coeff=policy.entropy_coeff,
|
||||
clip_param=policy.config["clip_param"],
|
||||
vf_clip_param=policy.config["vf_clip_param"],
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
use_gae=policy.config["use_gae"],
|
||||
)
|
||||
def reduce_mean_valid(t):
|
||||
return torch.sum(t[mask]) / num_valid
|
||||
|
||||
return policy.loss_obj.loss
|
||||
# non-RNN case: No masking.
|
||||
else:
|
||||
mask = None
|
||||
reduce_mean_valid = torch.mean
|
||||
|
||||
prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
|
||||
model)
|
||||
|
||||
logp_ratio = torch.exp(
|
||||
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
|
||||
train_batch[SampleBatch.ACTION_LOGP])
|
||||
action_kl = prev_action_dist.kl(curr_action_dist)
|
||||
mean_kl = reduce_mean_valid(action_kl)
|
||||
|
||||
curr_entropy = curr_action_dist.entropy()
|
||||
mean_entropy = reduce_mean_valid(curr_entropy)
|
||||
|
||||
surrogate_loss = torch.min(
|
||||
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
|
||||
train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
|
||||
logp_ratio, 1 - policy.config["clip_param"],
|
||||
1 + policy.config["clip_param"]))
|
||||
mean_policy_loss = reduce_mean_valid(-surrogate_loss)
|
||||
|
||||
if policy.config["use_gae"]:
|
||||
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
|
||||
value_fn_out = model.value_function()
|
||||
vf_loss1 = torch.pow(
|
||||
value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
|
||||
vf_clipped = prev_value_fn_out + torch.clamp(
|
||||
value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
|
||||
policy.config["vf_clip_param"])
|
||||
vf_loss2 = torch.pow(
|
||||
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
|
||||
vf_loss = torch.max(vf_loss1, vf_loss2)
|
||||
mean_vf_loss = reduce_mean_valid(vf_loss)
|
||||
total_loss = reduce_mean_valid(
|
||||
-surrogate_loss + policy.kl_coeff * action_kl +
|
||||
policy.config["vf_loss_coeff"] * vf_loss -
|
||||
policy.entropy_coeff * curr_entropy)
|
||||
else:
|
||||
mean_vf_loss = 0.0
|
||||
total_loss = reduce_mean_valid(-surrogate_loss +
|
||||
policy.kl_coeff * action_kl -
|
||||
policy.entropy_coeff * curr_entropy)
|
||||
|
||||
# Store stats in policy for stats_fn.
|
||||
policy._total_loss = total_loss
|
||||
policy._mean_policy_loss = mean_policy_loss
|
||||
policy._mean_vf_loss = mean_vf_loss
|
||||
policy._mean_entropy = mean_entropy
|
||||
policy._mean_kl = mean_kl
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def kl_and_loss_stats(policy, train_batch):
|
||||
def kl_and_loss_stats(policy: Policy,
|
||||
train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
"""Stats function for PPO. Returns a dict with important KL and loss stats.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to generate stats for.
|
||||
train_batch (SampleBatch): The SampleBatch (already) used for training.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: The stats dict.
|
||||
"""
|
||||
return {
|
||||
"cur_kl_coeff": policy.kl_coeff,
|
||||
"cur_lr": policy.cur_lr,
|
||||
"total_loss": policy.loss_obj.loss,
|
||||
"policy_loss": policy.loss_obj.mean_policy_loss,
|
||||
"vf_loss": policy.loss_obj.mean_vf_loss,
|
||||
"total_loss": policy._total_loss,
|
||||
"policy_loss": policy._mean_policy_loss,
|
||||
"vf_loss": policy._mean_vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
train_batch[Postprocessing.VALUE_TARGETS],
|
||||
policy.model.value_function()),
|
||||
"kl": policy.loss_obj.mean_kl,
|
||||
"entropy": policy.loss_obj.mean_entropy,
|
||||
"kl": policy._mean_kl,
|
||||
"entropy": policy._mean_entropy,
|
||||
"entropy_coeff": policy.entropy_coeff,
|
||||
}
|
||||
|
||||
|
||||
def vf_preds_fetches(policy, input_dict, state_batches, model, action_dist):
|
||||
"""Adds value function outputs to experience train_batches."""
|
||||
def vf_preds_fetches(
|
||||
policy: Policy, input_dict: Dict[str, TensorType],
|
||||
state_batches: List[TensorType], model: ModelV2,
|
||||
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
|
||||
"""Defines extra fetches per action computation.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to perform the extra action fetch on.
|
||||
input_dict (Dict[str, TensorType]): The input dict used for the action
|
||||
computing forward pass.
|
||||
state_batches (List[TensorType]): List of state tensors (empty for
|
||||
non-RNNs).
|
||||
model (ModelV2): The Model object of the Policy.
|
||||
action_dist (TorchDistributionWrapper): The instantiated distribution
|
||||
object, resulting from the model's outputs and the given
|
||||
distribution class.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: Dict with extra tf fetches to perform per
|
||||
action computation.
|
||||
"""
|
||||
# Return value function outputs. VF estimates will hence be added to the
|
||||
# SampleBatches produced by the sampler(s) to generate the train batches
|
||||
# going into the loss function.
|
||||
return {
|
||||
SampleBatch.VF_PREDS: policy.model.value_function(),
|
||||
}
|
||||
|
||||
|
||||
class KLCoeffMixin:
|
||||
"""Assigns the `update_kl()` method to the PPOPolicy.
|
||||
|
||||
This is used in PPO's execution plan (see ppo.py) for updating the KL
|
||||
coefficient after each learning step based on `config.kl_target` and
|
||||
the measured KL value (from the train_batch).
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
# KL Coefficient.
|
||||
# The current KL value (as python float).
|
||||
self.kl_coeff = config["kl_coeff"]
|
||||
# Constant target value.
|
||||
self.kl_target = config["kl_target"]
|
||||
|
||||
def update_kl(self, sampled_kl):
|
||||
# Update the current KL value based on the recently measured value.
|
||||
if sampled_kl > 2.0 * self.kl_target:
|
||||
self.kl_coeff *= 1.5
|
||||
elif sampled_kl < 0.5 * self.kl_target:
|
||||
self.kl_coeff *= 0.5
|
||||
# Return the current KL value.
|
||||
return self.kl_coeff
|
||||
|
||||
|
||||
class ValueNetworkMixin:
|
||||
"""Assigns the `_value()` method to the PPOPolicy.
|
||||
|
||||
This way, Policy can call `_value()` to get the current VF estimate on a
|
||||
single(!) observation (as done in `postprocess_trajectory_fn`).
|
||||
Note: When doing this, an actual forward pass is being performed.
|
||||
This is different from only calling `model.value_function()`, where
|
||||
the result of the most recent forward pass is being used to return an
|
||||
already calculated tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
# When doing GAE, we need the value function estimate on the
|
||||
# observation.
|
||||
if config["use_gae"]:
|
||||
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
|
@ -200,8 +225,10 @@ class ValueNetworkMixin:
|
|||
convert_to_torch_tensor(np.asarray([s]), self.device)
|
||||
for s in state
|
||||
], convert_to_torch_tensor(np.asarray([1]), self.device))
|
||||
# [0] = remove the batch dim.
|
||||
return self.model.value_function()[0]
|
||||
|
||||
# When not doing GAE, we do not require the value function's output.
|
||||
else:
|
||||
|
||||
def value(ob, prev_action, prev_reward, *state):
|
||||
|
@ -210,7 +237,17 @@ class ValueNetworkMixin:
|
|||
self._value = value
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Call all mixin classes' constructors before PPOPolicy initialization.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
obs_space (gym.spaces.Space): The Policy's observation space.
|
||||
action_space (gym.spaces.Space): The Policy's action space.
|
||||
config (TrainerConfigDict): The Policy's config.
|
||||
"""
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
KLCoeffMixin.__init__(policy, config)
|
||||
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
|
||||
|
@ -218,7 +255,20 @@ def setup_mixins(policy, obs_space, action_space, config):
|
|||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
def training_view_requirements_fn(policy):
|
||||
def training_view_requirements_fn(
|
||||
policy: Policy) -> Dict[str, ViewRequirement]:
|
||||
"""Function defining the view requirements for training the policy.
|
||||
|
||||
These go on top of the Policy's Model's own view requirements used for
|
||||
action computing forward passes.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy that requires the returned
|
||||
ViewRequirements.
|
||||
|
||||
Returns:
|
||||
Dict[str, ViewRequirement]: The Policy's view requirements.
|
||||
"""
|
||||
return {
|
||||
# Next obs are needed for PPO postprocessing.
|
||||
SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
|
||||
|
@ -233,6 +283,8 @@ def training_view_requirements_fn(policy):
|
|||
}
|
||||
|
||||
|
||||
# Build a child class of `TorchPolicy`, given the custom functions defined
|
||||
# above.
|
||||
PPOTorchPolicy = build_torch_policy(
|
||||
name="PPOTorchPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
|
||||
|
|
|
@ -21,12 +21,13 @@ class TestAPPO(unittest.TestCase):
|
|||
config["num_workers"] = 1
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config, frameworks=("torch", "tf")):
|
||||
for _ in framework_iterator(config):
|
||||
_config = config.copy()
|
||||
trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
|
||||
_config = config.copy()
|
||||
_config["vtrace"] = True
|
||||
|
@ -34,6 +35,7 @@ class TestAPPO(unittest.TestCase):
|
|||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -268,9 +268,11 @@ class TestPPO(unittest.TestCase):
|
|||
policy_sess = policy.get_session()
|
||||
k, e, pl, v, tl = policy_sess.run(
|
||||
[
|
||||
policy.loss_obj.mean_kl, policy.loss_obj.mean_entropy,
|
||||
policy.loss_obj.mean_policy_loss,
|
||||
policy.loss_obj.mean_vf_loss, policy.loss_obj.loss
|
||||
policy._mean_kl,
|
||||
policy._mean_entropy,
|
||||
policy._mean_policy_loss,
|
||||
policy._mean_vf_loss,
|
||||
policy._total_loss,
|
||||
],
|
||||
feed_dict=policy._get_loss_inputs_dict(
|
||||
train_batch, shuffle=False))
|
||||
|
@ -280,12 +282,11 @@ class TestPPO(unittest.TestCase):
|
|||
check(v, np.mean(vf_loss), decimals=4)
|
||||
check(tl, overall_loss, decimals=4)
|
||||
else:
|
||||
check(policy.loss_obj.mean_kl, kl)
|
||||
check(policy.loss_obj.mean_entropy, entropy)
|
||||
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
|
||||
check(
|
||||
policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
|
||||
check(policy.loss_obj.loss, overall_loss, decimals=4)
|
||||
check(policy._mean_kl, kl)
|
||||
check(policy._mean_entropy, entropy)
|
||||
check(policy._mean_policy_loss, np.mean(-pg_loss))
|
||||
check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
|
||||
check(policy._total_loss, overall_loss, decimals=4)
|
||||
trainer.stop()
|
||||
|
||||
def _ppo_loss_helper(self,
|
||||
|
|
|
@ -35,8 +35,8 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
|
|||
def build_trainer(
|
||||
name: str,
|
||||
*,
|
||||
default_config: TrainerConfigDict = None,
|
||||
validate_config: Callable[[TrainerConfigDict], None] = None,
|
||||
default_config: Optional[TrainerConfigDict] = None,
|
||||
validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
|
||||
default_policy: Optional[Type[Policy]] = None,
|
||||
get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
|
||||
Policy]]]] = None,
|
||||
|
@ -46,7 +46,7 @@ def build_trainer(
|
|||
mixins: Optional[List[type]] = None,
|
||||
execution_plan: Optional[Callable[[
|
||||
WorkerSet, TrainerConfigDict
|
||||
], Iterable[ResultDict]]] = default_execution_plan):
|
||||
], Iterable[ResultDict]]] = default_execution_plan) -> Type[Trainer]:
|
||||
"""Helper function for defining a custom trainer.
|
||||
|
||||
Functions will be run in this order to initialize the trainer:
|
||||
|
@ -56,11 +56,11 @@ def build_trainer(
|
|||
|
||||
Args:
|
||||
name (str): name of the trainer (e.g., "PPO")
|
||||
default_config (TrainerConfigDict): The default config dict
|
||||
default_config (Optional[TrainerConfigDict]): The default config dict
|
||||
of the algorithm, otherwise uses the Trainer default config.
|
||||
validate_config (Optional[callable]): Optional callable that takes the
|
||||
config to check for correctness. It may mutate the config as
|
||||
needed.
|
||||
validate_config (Optional[Callable[[TrainerConfigDict], None]]):
|
||||
Optional callable that takes the config to check for correctness.
|
||||
It may mutate the config as needed.
|
||||
default_policy (Optional[Type[Policy]]): The default Policy class to
|
||||
use.
|
||||
get_policy_class (Optional[Callable[
|
||||
|
@ -81,10 +81,12 @@ def build_trainer(
|
|||
mixins (list): list of any class mixins for the returned trainer class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the Trainer class.
|
||||
execution_plan (func): Setup the distributed execution workflow.
|
||||
execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict],
|
||||
Iterable[ResultDict]]]): Optional callable that sets up the
|
||||
distributed execution workflow.
|
||||
|
||||
Returns:
|
||||
a Trainer instance that uses the specified args.
|
||||
Type[Trainer]: A Trainer sub-class configured by the specified args.
|
||||
"""
|
||||
|
||||
original_kwargs = locals().copy()
|
||||
|
|
|
@@ -21,9 +21,9 @@ import ray
from ray import tune
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin, \
    PPOLoss as TFLoss
    ppo_surrogate_loss as tf_loss
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy, \
    KLCoeffMixin as TorchKLCoeffMixin, PPOLoss as TorchLoss
    KLCoeffMixin as TorchKLCoeffMixin, ppo_surrogate_loss as torch_loss
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.examples.env.two_step_game import TwoStepGame
@@ -119,42 +119,22 @@ def centralized_critic_postprocessing(policy,
    return train_batch


# Copied from PPO but optimizing the central value function
# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, model, dist_class, train_batch):
    CentralizedValueMixin.__init__(policy)
    func = tf_loss if not policy.config["framework"] == "torch" else torch_loss

    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    policy.central_value_out = policy.model.central_value_function(
    vf_saved = model.value_function
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION])

    func = TFLoss if not policy.config["framework"] == "torch" else TorchLoss
    adv = tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool) \
        if policy.config["framework"] != "torch" else \
        torch.ones_like(train_batch[Postprocessing.ADVANTAGES],
                        dtype=torch.bool)
    policy._central_value_out = model.value_function()
    loss = func(policy, model, dist_class, train_batch)

    policy.loss_obj = func(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        policy.central_value_out,
        policy.kl_coeff,
        adv,
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"])
    model.value_function = vf_saved

    return policy.loss_obj.loss
    return loss


def setup_mixins(policy, obs_space, action_space, config):
@@ -170,7 +150,7 @@ def central_vf_stats(policy, train_batch, grads):
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS],
            policy.central_value_out),
            policy._central_value_out),
    }


@@ -197,8 +177,8 @@ CCPPOTorchPolicy = PPOTorchPolicy.with_updates(


def get_policy_class(config):
    return CCPPOTorchPolicy if config["framework"] == "torch" \
        else CCPPOTFPolicy
    if config["framework"] == "torch":
        return CCPPOTorchPolicy


CCTrainer = PPOTrainer.with_updates(
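The main change in `loss_with_central_critic` above is that it no longer re-builds a PPO loss object by hand; it temporarily swaps the model's value function for the centralized critic and reuses the stock surrogate loss. A commented restatement of that pattern (identifiers as in the example above; this is a sketch, not additional code from the commit):

    def loss_with_central_critic(policy, model, dist_class, train_batch):
        CentralizedValueMixin.__init__(policy)
        # Pick the framework-specific stock PPO surrogate loss.
        func = tf_loss if policy.config["framework"] != "torch" else torch_loss

        # Remember the model's own value function ...
        vf_saved = model.value_function
        # ... and temporarily point it at the central critic, which also sees
        # the opponent's observation and action.
        model.value_function = lambda: policy.model.central_value_function(
            train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
            train_batch[OPPONENT_ACTION])
        # Keep the central VF output around for `central_vf_stats`.
        policy._central_value_out = model.value_function()

        # The unmodified PPO loss now optimizes the central value function.
        loss = func(policy, model, dist_class, train_batch)

        # Restore the original value function before returning.
        model.value_function = vf_saved
        return loss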
@@ -25,7 +25,7 @@ def StandardMetricsReporting(
            to collect metrics from.

    Returns:
        A local iterator over training results.
        LocalIterator[dict]: A local iterator over training results.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
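The hunk above only tightens the return annotation of `StandardMetricsReporting` to `LocalIterator[dict]`. For orientation, a typical execution plan wires it up roughly as sketched below; the module paths are the usual RLlib ones, and `workers`/`config` are assumed to be the plan's arguments:

    from ray.rllib.execution.metric_ops import StandardMetricsReporting
    from ray.rllib.execution.rollout_ops import ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep

    def execution_plan(workers, config):
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        train_op = rollouts.for_each(TrainOneStep(workers))
        # Yields a LocalIterator[dict] over training results, per the
        # updated docstring.
        return StandardMetricsReporting(train_op, workers, config)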
@@ -152,7 +152,7 @@ class TestDistributions(unittest.TestCase):

    def test_squashed_gaussian(self):
        """Tests the SquashedGaussian ActionDistribution for all frameworks."""
        input_space = Box(-2.0, 2.0, shape=(200, 10))
        input_space = Box(-2.0, 2.0, shape=(2000, 10))
        low, high = -2.0, 1.0

        for fw, sess in framework_iterator(
@@ -245,7 +245,7 @@ class TestDistributions(unittest.TestCase):

    def test_diag_gaussian(self):
        """Tests the DiagGaussian ActionDistribution for all frameworks."""
        input_space = Box(-2.0, 2.0, shape=(200, 10))
        input_space = Box(-2.0, 2.0, shape=(2000, 10))

        for fw, sess in framework_iterator(
                frameworks=("torch", "tf", "tfe"), session=True):
@@ -310,9 +310,9 @@ class TestDistributions(unittest.TestCase):
            check(outs, log_prob, decimals=4)

    def test_beta(self):
        input_space = Box(-2.0, 1.0, shape=(200, 10))
        input_space = Box(-2.0, 1.0, shape=(2000, 10))
        low, high = -1.0, 2.0
        plain_beta_value_space = Box(0.0, 1.0, shape=(200, 5))
        plain_beta_value_space = Box(0.0, 1.0, shape=(2000, 5))

        for fw, sess in framework_iterator(session=True):
            cls = TorchBeta if fw == "torch" else Beta
@@ -361,7 +361,7 @@ class TestDistributions(unittest.TestCase):
            check(
                out,
                np.sum(np.log(beta.pdf(values, alpha, beta_)), -1),
                rtol=0.001)
                rtol=0.01)

        # TODO(sven): Test entropy outputs (against scipy).
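The distribution-test hunks above only enlarge the sampled batches (200 -> 2000) and loosen one relative tolerance, which makes the statistical comparisons less noisy. A self-contained sketch of the same kind of check against scipy, assuming eager TF; the distribution parameters and the tolerance here are illustrative, not taken from the test file:

    import numpy as np
    from scipy.stats import norm

    from ray.rllib.models.tf.tf_action_dist import DiagGaussian
    from ray.rllib.utils.test_utils import check

    batch, dim = 2000, 10
    # Inputs are mean and log-std concatenated along the last axis;
    # all zeros gives a standard normal per dimension.
    inputs = np.zeros((batch, 2 * dim), dtype=np.float32)
    dist = DiagGaussian(inputs, None)

    values = np.random.normal(size=(batch, dim)).astype(np.float32)
    expected = np.sum(np.log(norm.pdf(values, 0.0, 1.0)), -1)
    # Compare RLlib's log-prob against scipy, as the beta test above does.
    check(dist.logp(values), expected, rtol=0.01)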
@@ -10,7 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import ModelGradients, TensorType, \
from ray.rllib.utils.typing import AgentID, ModelGradients, TensorType, \
    TrainerConfigDict


@@ -24,8 +24,8 @@ def build_tf_policy(
        get_default_config: Optional[Callable[[None],
                                              TrainerConfigDict]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[List[SampleBatch]], Optional[
                "MultiAgentEpisode"]
            Policy, SampleBatch, Optional[Dict[AgentID, SampleBatch]],
            Optional["MultiAgentEpisode"]
        ], SampleBatch]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
@@ -63,7 +63,7 @@ def build_tf_policy(
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        obs_include_prev_action_reward: bool = True):
        obs_include_prev_action_reward: bool = True) -> Type[TFPolicy]:
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
@@ -94,9 +94,9 @@ def build_tf_policy(
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], None]]): Optional callable
            for post-processing experience batches (called after the
            super's `postprocess_trajectory` method).
            Optional[Dict[AgentID, SampleBatch]], MultiAgentEpisode], None]]):
            Optional callable for post-processing experience batches (called
            after the parent class' `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
@@ -172,7 +172,8 @@ def build_tf_policy(
            previous action and reward in the model input.

    Returns:
        a DynamicTFPolicy instance that uses the specified args
        Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
            specified args.
    """
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)
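The `build_tf_policy` changes above tighten the `postprocess_fn` hint (the other agents' batches are now a `Dict[AgentID, SampleBatch]` rather than a list, with `AgentID` newly imported) and annotate the return type as `Type[TFPolicy]`. A minimal sketch of a policy built with such a callable; the loss function and the extra batch column are made-up placeholders:

    import numpy as np

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.tf_policy_template import build_tf_policy
    from ray.rllib.utils import try_import_tf

    tf1, tf, tfv = try_import_tf()

    def my_postprocess_fn(policy, sample_batch,
                          other_agent_batches=None, episode=None):
        # `other_agent_batches` is keyed by AgentID, per the new hint;
        # here we only tag the batch with a dummy column.
        sample_batch["my_flag"] = np.ones_like(
            sample_batch[SampleBatch.REWARDS])
        return sample_batch

    def my_loss_fn(policy, model, dist_class, train_batch):
        # Placeholder loss: maximize log-likelihood of the taken actions.
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return -tf.reduce_mean(
            action_dist.logp(train_batch[SampleBatch.ACTIONS]))

    # Returns a child class of DynamicTFPolicy, per the updated docstring.
    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=my_loss_fn,
        postprocess_fn=my_postprocess_fn)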
@@ -261,13 +261,14 @@ class JsonIOTest(unittest.TestCase):
        for _ in range(100):
            writer.write(SAMPLES)
        num_files = len(os.listdir(self.test_dir))
        # Magic numbers: 2: On travis, it seems to create only 2 files,
        # but sometimes also 7.
        # 12 or 13: Mac locally.

        # Pagination can't really be predicted:
        # On travis, it seems to create only 2 files, but sometimes also
        # 6, or 7. 12 or 13 usually on a Mac locally.
        # Reasons: Different compressions, file-size interpretations,
        # json writers?
        assert num_files in [2, 7, 12, 13], \
            "Expected 2|7|12|13 files, but found {} ({})". \
        assert num_files >= 2, \
            "Expected >= 2 files, but found {} ({})". \
            format(num_files, os.listdir(self.test_dir))

    def test_read_write(self):
@@ -2,6 +2,7 @@ from gym.spaces import Box, Discrete
import os
from pathlib import Path
import re
import sys
import unittest

import ray
@@ -257,7 +258,6 @@ class TestRolloutLearntPolicy(unittest.TestCase):


if __name__ == "__main__":
    import sys
    import pytest

    # One can specify the specific TestCase class to run.
@@ -15,10 +15,7 @@ from ray.rllib.utils.test_utils import framework_iterator
ACTION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2": Box(-1.0, 1.0, (
        5,
        5,
    ), dtype=np.float32),
    # "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
    "tuple": Tuple(
        [Discrete(2),
@@ -91,15 +88,19 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False):
    a.stop()
    print(stat)

    frameworks = ("torch", "tf")
    frameworks = ("tf", "torch")
    if tfe:
        frameworks += ("tfe", )
    for _ in framework_iterator(config, frameworks=frameworks):
        # Check all action spaces (using a discrete obs-space).
        for a_name, action_space in ACTION_SPACES_TO_TEST.items():
        for a_name in ACTION_SPACES_TO_TEST.keys():
            _do_check(alg, config, a_name, "discrete")
        # Check all obs spaces (using a supported action-space).
        for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
        for o_name in OBSERVATION_SPACES_TO_TEST.keys():
            # We already tested discrete observation spaces against all action
            # spaces above -> skip.
            if o_name == "discrete":
                continue
            a_name = "discrete" if alg not in ["DDPG", "SAC"] else "vector"
            _do_check(alg, config, a_name, o_name)
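The loop changes above drop the unused dict values (`.items()` -> `.keys()`), since only the space names are needed. A trivial illustration of the idiom in plain Python:

    spaces = {"discrete": "Discrete(5)", "vector": "Box(-1.0, 1.0, (5,))"}
    # Only the names are needed, so iterate the keys directly:
    for name in spaces.keys():  # equivalently: for name in spaces:
        print(name)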