Revert "Revert [RLlib] POC: Deprecate build_policy (policy template) for torch only; PPOTorchPolicy (#20061) (#20399)" (#20417)

This reverts commit 90dc5460d4.
Sven Mika 2021-11-16 14:49:41 +01:00 committed by GitHub
parent 6504ad6bb2
commit f82880eda1
18 changed files with 485 additions and 345 deletions

View file

@ -117,6 +117,7 @@ class Trainable:
self._stderr_file = stderr_file
start_time = time.time()
self._local_ip = self.get_current_ip()
self.setup(copy.deepcopy(self.config))
setup_time = time.time() - start_time
if setup_time > SETUP_TIME_THRESHOLD:
@ -124,7 +125,6 @@ class Trainable:
"trainable is slow to initialize, consider setting "
"reuse_actors=True to reduce actor creation "
"overheads.".format(setup_time))
self._local_ip = self.get_current_ip()
log_sys_usage = self.config.get("log_sys_usage", False)
self._monitor = UtilMonitor(start=log_sys_usage)

View file

@ -1654,7 +1654,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=60"]
args = ["--as-test", "--stop-reward=70"]
)
py_test(
@ -1663,7 +1663,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=60", "--framework torch"]
args = ["--as-test", "--stop-reward=70", "--framework torch"]
)
py_test(

View file

@ -1,13 +1,13 @@
import gym
from typing import Optional, Dict
from typing import Dict, List, Optional
import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
@ -97,11 +97,33 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
}
def model_value_predictions(
policy: Policy, input_dict: Dict[str, TensorType], state_batches,
model: ModelV2,
action_dist: ActionDistribution) -> Dict[str, TensorType]:
return {SampleBatch.VF_PREDS: model.value_function()}
def vf_preds_fetches(
policy: Policy, input_dict: Dict[str, TensorType],
state_batches: List[TensorType], model: ModelV2,
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.
Args:
policy (Policy): The Policy to perform the extra action fetch on.
input_dict (Dict[str, TensorType]): The input dict used for the action
computing forward pass.
state_batches (List[TensorType]): List of state tensors (empty for
non-RNNs).
model (ModelV2): The Model object of the Policy.
action_dist (TorchDistributionWrapper): The instantiated distribution
object, resulting from the model's outputs and the given
distribution class.
Returns:
Dict[str, TensorType]: Dict with extra fetches to perform per
action computation.
"""
# Return value function outputs. VF estimates will hence be added to the
# SampleBatches produced by the sampler(s) to generate the train batches
# going into the loss function.
return {
SampleBatch.VF_PREDS: model.value_function(),
}
def torch_optimizer(policy: Policy,
@ -109,6 +131,41 @@ def torch_optimizer(policy: Policy,
return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
class ValueNetworkMixin:
"""Assigns the `_value()` method to the PPOPolicy.
This way, Policy can call `_value()` to get the current VF estimate on a
single(!) observation (as done in `postprocess_trajectory_fn`).
Note: When doing this, an actual forward pass is being performed.
This is different from only calling `model.value_function()`, where
the result of the most recent forward pass is being used to return an
already calculated tensor.
"""
def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
def value(**input_dict):
input_dict = SampleBatch(input_dict)
input_dict = self._lazy_tensor_dict(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0].item()
# When not doing GAE, we do not require the value function's output.
else:
def value(*args, **kwargs):
return 0.0
self._value = value
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
@ -133,7 +190,7 @@ A3CTorchPolicy = build_policy_class(
loss_fn=actor_critic_loss,
stats_fn=stats,
postprocess_fn=compute_gae_for_sample_batch,
extra_action_out_fn=model_value_predictions,
extra_action_out_fn=vf_preds_fetches,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=torch_optimizer,
before_loss_init=setup_mixins,
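
As the `vf_preds_fetches` docstring above notes, the dict returned by the extra-action-out function is attached to each sampled timestep, so its keys end up as additional columns of the collected `SampleBatch`, next to observations, actions and rewards. A small illustration of that end result (the literal arrays are made up):

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    SampleBatch.OBS: np.zeros((3, 4), np.float32),
    SampleBatch.ACTIONS: np.array([0, 1, 0]),
    # Column contributed by the extra-action-out fetch (VF estimates).
    SampleBatch.VF_PREDS: np.array([0.10, 0.25, 0.31], np.float32),
})
print(batch[SampleBatch.VF_PREDS])  # -> [0.1  0.25 0.31]
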

View file

@ -5,9 +5,9 @@ from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin, \
vf_preds_fetches
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
ValueNetworkMixin
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.framework import try_import_torch

View file

@ -2,7 +2,7 @@ import gym
from typing import Dict
import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.policy_template import build_policy_class

View file

@ -4,10 +4,10 @@ import logging
from typing import Tuple, Type
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2

View file

@ -16,8 +16,7 @@ from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
choose_optimizer
from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model, \
postprocess_trajectory
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \
KLCoeffMixin
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import \
@ -281,6 +280,30 @@ def add_values(policy, input_dict, state_batches, model, action_dist):
return out
class KLCoeffMixin:
"""Assigns the `update_kl()` method to the PPOPolicy.
This is used in PPO's execution plan (see ppo.py) for updating the KL
coefficient after each learning step based on `config.kl_target` and
the measured KL value (from the train_batch).
"""
def __init__(self, config):
# The current KL value (as python float).
self.kl_coeff = config["kl_coeff"]
# Constant target value.
self.kl_target = config["kl_target"]
def update_kl(self, sampled_kl):
# Update the current KL value based on the recently measured value.
if sampled_kl > 2.0 * self.kl_target:
self.kl_coeff *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff *= 0.5
# Return the current KL value.
return self.kl_coeff
def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict):
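
The `update_kl()` rule above adapts the KL penalty between training iterations: if the measured KL overshoots twice `kl_target`, the coefficient is raised by 50%; if it undershoots half the target, it is halved. A minimal standalone sketch of that rule (the `AdaptiveKL` class and the sample numbers are illustrative, not RLlib code):

class AdaptiveKL:
    def __init__(self, kl_coeff: float, kl_target: float):
        self.kl_coeff = kl_coeff    # current penalty coefficient
        self.kl_target = kl_target  # desired mean KL per train step

    def update(self, sampled_kl: float) -> float:
        # Policy moved too far from the behavior policy -> penalize harder.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff *= 1.5
        # Policy barely moved -> relax the penalty.
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff *= 0.5
        return self.kl_coeff

kl = AdaptiveKL(kl_coeff=0.2, kl_target=0.01)
print(round(kl.update(0.05), 3))   # 0.05 > 2 * 0.01   -> grows to 0.3
print(round(kl.update(0.002), 3))  # 0.002 < 0.5 * 0.01 -> shrinks to 0.15
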

View file

@ -1,7 +1,3 @@
"""
PyTorch policy class used for PPO.
"""
import gym
import logging
from typing import Dict, List, Type, Union
@ -10,185 +6,167 @@ from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
LearningRateSchedule
LearningRateSchedule, TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping, \
explained_variance, sequence_mask
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
def ppo_surrogate_loss(
policy: Policy, model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.
class PPOTorchPolicy(TorchPolicy, LearningRateSchedule, EntropyCoeffSchedule):
"""PyTorch policy class used with PPOTrainer."""
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]): The action distr. class.
train_batch (SampleBatch): The training data.
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
setup_config(self, observation_space, action_space, config)
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
logits, state = model(train_batch)
curr_action_dist = dist_class(logits, model)
TorchPolicy.__init__(
self,
observation_space,
action_space,
config,
max_seq_len=config["model"]["max_seq_len"])
# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
B = len(train_batch[SampleBatch.SEQ_LENS])
max_seq_len = logits.shape[0] // B
mask = sequence_mask(
train_batch[SampleBatch.SEQ_LENS],
max_seq_len,
time_major=model.is_time_major())
mask = torch.reshape(mask, [-1])
num_valid = torch.sum(mask)
EntropyCoeffSchedule.__init__(self, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(self, config["lr"],
config["lr_schedule"])
def reduce_mean_valid(t):
return torch.sum(t[mask]) / num_valid
# non-RNN case: No masking.
else:
mask = None
reduce_mean_valid = torch.mean
prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
model)
logp_ratio = torch.exp(
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
train_batch[SampleBatch.ACTION_LOGP])
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl_loss = reduce_mean_valid(action_kl)
curr_entropy = curr_action_dist.entropy()
mean_entropy = reduce_mean_valid(curr_entropy)
surrogate_loss = torch.min(
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
logp_ratio, 1 - policy.config["clip_param"],
1 + policy.config["clip_param"]))
mean_policy_loss = reduce_mean_valid(-surrogate_loss)
# Compute a value function loss.
if policy.config["use_critic"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
vf_loss1 = torch.pow(
value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped = prev_value_fn_out + torch.clamp(
value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
policy.config["vf_clip_param"])
vf_loss2 = torch.pow(
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss = torch.max(vf_loss1, vf_loss2)
mean_vf_loss = reduce_mean_valid(vf_loss)
# Ignore the value function.
else:
vf_loss = mean_vf_loss = 0.0
total_loss = reduce_mean_valid(-surrogate_loss +
policy.kl_coeff * action_kl +
policy.config["vf_loss_coeff"] * vf_loss -
policy.entropy_coeff * curr_entropy)
# Store values for stats function in model (tower), such that for
# multi-GPU, we do not override them during the parallel loss phase.
model.tower_stats["total_loss"] = total_loss
model.tower_stats["mean_policy_loss"] = mean_policy_loss
model.tower_stats["mean_vf_loss"] = mean_vf_loss
model.tower_stats["vf_explained_var"] = explained_variance(
train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
model.tower_stats["mean_entropy"] = mean_entropy
model.tower_stats["mean_kl_loss"] = mean_kl_loss
return total_loss
def kl_and_loss_stats(policy: Policy,
train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Stats function for PPO. Returns a dict with important KL and loss stats.
Args:
policy (Policy): The Policy to generate stats for.
train_batch (SampleBatch): The SampleBatch (already) used for training.
Returns:
Dict[str, TensorType]: The stats dict.
"""
return {
"cur_kl_coeff": policy.kl_coeff,
"cur_lr": policy.cur_lr,
"total_loss": torch.mean(
torch.stack(policy.get_tower_stats("total_loss"))),
"policy_loss": torch.mean(
torch.stack(policy.get_tower_stats("mean_policy_loss"))),
"vf_loss": torch.mean(
torch.stack(policy.get_tower_stats("mean_vf_loss"))),
"vf_explained_var": torch.mean(
torch.stack(policy.get_tower_stats("vf_explained_var"))),
"kl": torch.mean(torch.stack(policy.get_tower_stats("mean_kl_loss"))),
"entropy": torch.mean(
torch.stack(policy.get_tower_stats("mean_entropy"))),
"entropy_coeff": policy.entropy_coeff,
}
def vf_preds_fetches(
policy: Policy, input_dict: Dict[str, TensorType],
state_batches: List[TensorType], model: ModelV2,
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.
Args:
policy (Policy): The Policy to perform the extra action fetch on.
input_dict (Dict[str, TensorType]): The input dict used for the action
computing forward pass.
state_batches (List[TensorType]): List of state tensors (empty for
non-RNNs).
model (ModelV2): The Model object of the Policy.
action_dist (TorchDistributionWrapper): The instantiated distribution
object, resulting from the model's outputs and the given
distribution class.
Returns:
Dict[str, TensorType]: Dict with extra fetches to perform per
action computation.
"""
# Return value function outputs. VF estimates will hence be added to the
# SampleBatches produced by the sampler(s) to generate the train batches
# going into the loss function.
return {
SampleBatch.VF_PREDS: model.value_function(),
}
class KLCoeffMixin:
"""Assigns the `update_kl()` method to the PPOPolicy.
This is used in PPO's execution plan (see ppo.py) for updating the KL
coefficient after each learning step based on `config.kl_target` and
the measured KL value (from the train_batch).
"""
def __init__(self, config):
# The current KL value (as python float).
self.kl_coeff = config["kl_coeff"]
self.kl_coeff = self.config["kl_coeff"]
# Constant target value.
self.kl_target = config["kl_target"]
self.kl_target = self.config["kl_target"]
# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()
@override(TorchPolicy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
# Do all post-processing always with no_grad().
# Not using this here will introduce a memory leak
# in torch (issue #6962).
# TODO: no_grad still necessary?
with torch.no_grad():
return compute_gae_for_sample_batch(self, sample_batch,
other_agent_batches, episode)
# TODO: Add method to Policy base class (as the new way of defining loss
# functions (instead of passing 'loss` to the super's constructor)).
@override(TorchPolicy)
def loss(self, model: ModelV2, dist_class: Type[ActionDistribution],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.
Args:
model: The Model to calculate the loss for.
dist_class: The action distr. class.
train_batch: The training data.
Returns:
The PPO loss tensor given the input batch.
"""
logits, state = model(train_batch)
curr_action_dist = dist_class(logits, model)
# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
B = len(train_batch[SampleBatch.SEQ_LENS])
max_seq_len = logits.shape[0] // B
mask = sequence_mask(
train_batch[SampleBatch.SEQ_LENS],
max_seq_len,
time_major=model.is_time_major())
mask = torch.reshape(mask, [-1])
num_valid = torch.sum(mask)
def reduce_mean_valid(t):
return torch.sum(t[mask]) / num_valid
# non-RNN case: No masking.
else:
mask = None
reduce_mean_valid = torch.mean
prev_action_dist = dist_class(
train_batch[SampleBatch.ACTION_DIST_INPUTS], model)
logp_ratio = torch.exp(
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
train_batch[SampleBatch.ACTION_LOGP])
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl_loss = reduce_mean_valid(action_kl)
curr_entropy = curr_action_dist.entropy()
mean_entropy = reduce_mean_valid(curr_entropy)
surrogate_loss = torch.min(
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
logp_ratio, 1 - self.config["clip_param"],
1 + self.config["clip_param"]))
mean_policy_loss = reduce_mean_valid(-surrogate_loss)
# Compute a value function loss.
if self.config["use_critic"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
vf_loss1 = torch.pow(
value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped = prev_value_fn_out + torch.clamp(
value_fn_out - prev_value_fn_out,
-self.config["vf_clip_param"], self.config["vf_clip_param"])
vf_loss2 = torch.pow(
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss = torch.max(vf_loss1, vf_loss2)
mean_vf_loss = reduce_mean_valid(vf_loss)
# Ignore the value function.
else:
vf_loss = mean_vf_loss = 0.0
total_loss = reduce_mean_valid(-surrogate_loss +
self.kl_coeff * action_kl +
self.config["vf_loss_coeff"] * vf_loss -
self.entropy_coeff * curr_entropy)
# Store values for stats function in model (tower), such that for
# multi-GPU, we do not override them during the parallel loss phase.
model.tower_stats["total_loss"] = total_loss
model.tower_stats["mean_policy_loss"] = mean_policy_loss
model.tower_stats["mean_vf_loss"] = mean_vf_loss
model.tower_stats["vf_explained_var"] = explained_variance(
train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
model.tower_stats["mean_entropy"] = mean_entropy
model.tower_stats["mean_kl_loss"] = mean_kl_loss
return total_loss
def _value(self, **input_dict):
# When doing GAE, we need the value function estimate on the
# observation.
if self.config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
input_dict = self._lazy_tensor_dict(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0].item()
# When not doing GAE, we do not require the value function's output.
else:
return 0.0
def update_kl(self, sampled_kl):
# Update the current KL value based on the recently measured value.
@ -199,75 +177,56 @@ class KLCoeffMixin:
# Return the current KL value.
return self.kl_coeff
# TODO: Make this an event-style subscription (e.g.:
# "after_actions_computed").
@override(TorchPolicy)
def extra_action_out(self, input_dict, state_batches, model, action_dist):
# Return value function outputs. VF estimates will hence be added to
# the SampleBatches produced by the sampler(s) to generate the train
# batches going into the loss function.
return {
SampleBatch.VF_PREDS: model.value_function(),
}
class ValueNetworkMixin:
"""Assigns the `_value()` method to the PPOPolicy.
# TODO: Make this an event-style subscription (e.g.:
# "after_gradients_computed").
@override(TorchPolicy)
def extra_grad_process(self, local_optimizer, loss):
return apply_grad_clipping(self, local_optimizer, loss)
This way, Policy can call `_value()` to get the current VF estimate on a
single(!) observation (as done in `postprocess_trajectory_fn`).
Note: When doing this, an actual forward pass is being performed.
This is different from only calling `model.value_function()`, where
the result of the most recent forward pass is being used to return an
already calculated tensor.
"""
# TODO: Make this an event-style subscription (e.g.:
# "after_losses_computed").
@override(TorchPolicy)
def extra_grad_info(self,
train_batch: SampleBatch) -> Dict[str, TensorType]:
return convert_to_numpy({
"cur_kl_coeff": self.kl_coeff,
"cur_lr": self.cur_lr,
"total_loss": torch.mean(
torch.stack(self.get_tower_stats("total_loss"))),
"policy_loss": torch.mean(
torch.stack(self.get_tower_stats("mean_policy_loss"))),
"vf_loss": torch.mean(
torch.stack(self.get_tower_stats("mean_vf_loss"))),
"vf_explained_var": torch.mean(
torch.stack(self.get_tower_stats("vf_explained_var"))),
"kl": torch.mean(
torch.stack(self.get_tower_stats("mean_kl_loss"))),
"entropy": torch.mean(
torch.stack(self.get_tower_stats("mean_entropy"))),
"entropy_coeff": self.entropy_coeff,
})
def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
def value(**input_dict):
input_dict = SampleBatch(input_dict)
input_dict = self._lazy_tensor_dict(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0].item()
# When not doing GAE, we do not require the value function's output.
else:
def value(*args, **kwargs):
return 0.0
self._value = value
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Call all mixin classes' constructors before PPOPolicy initialization.
Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
KLCoeffMixin.__init__(policy, config)
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
PPOTorchPolicy = build_policy_class(
name="PPOTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
loss_fn=ppo_surrogate_loss,
stats_fn=kl_and_loss_stats,
extra_action_out_fn=vf_preds_fetches,
postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
],
)
# TODO: Make lr-schedule and entropy-schedule Plugin-style functionalities
# that can be added (via the config) to any Trainer/Policy.
@override(TorchPolicy)
def on_global_var_update(self, global_vars):
super().on_global_var_update(global_vars)
if self._lr_schedule:
self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
for opt in self._optimizers:
for p in opt.param_groups:
p["lr"] = self.cur_lr
if self._entropy_coeff_schedule is not None:
self.entropy_coeff = self._entropy_coeff_schedule.value(
global_vars["timestep"])
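
The `loss()` method above combines PPO's clipped surrogate objective with the KL penalty, value-function loss and entropy bonus. The core surrogate term clips the probability ratio between the current and the behavior policy to `[1 - clip_param, 1 + clip_param]` and keeps the pessimistic minimum. A torch-only sketch of just that term (helper name and sample tensors are illustrative, not RLlib code):

import torch

def clipped_surrogate(logp_new, logp_old, advantages, clip_param=0.3):
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    ratio = torch.exp(logp_new - logp_old)
    unclipped = advantages * ratio
    clipped = advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Element-wise pessimistic bound, then negate to obtain a loss to minimize.
    return -torch.mean(torch.min(unclipped, clipped))

advantages = torch.tensor([1.0, -0.5])
logp_old = torch.tensor([-1.0, -1.2])
logp_new = torch.tensor([-0.4, -1.5])
print(clipped_surrogate(logp_new, logp_old, advantages))
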

View file

@ -7,8 +7,7 @@ from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
ppo_surrogate_loss_tf
from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
ppo_surrogate_loss_torch
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
@ -272,8 +271,8 @@ class TestPPO(unittest.TestCase):
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
train_batch)
elif fw == "torch":
ppo_surrogate_loss_torch(policy, policy.model,
TorchCategorical, train_batch)
PPOTorchPolicy.loss(policy, policy.model, policy.dist_class,
train_batch)
vars = policy.model.variables() if fw != "torch" else \
list(policy.model.parameters())

View file

@ -123,7 +123,7 @@ class AlphaZeroPolicy(TorchPolicy):
sample_batch["t"])
return sample_batch
@override(Policy)
@override(TorchPolicy)
def learn_on_batch(self, postprocessed_batch):
train_batch = self._lazy_tensor_dict(postprocessed_batch)

View file

@ -23,9 +23,9 @@ def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
n is truncated to fit in the traj length.
Args:
n_step (int): The number of steps to look ahead and adjust.
gamma (float): The discount factor.
batch (SampleBatch): The SampleBatch to adjust (in place).
n_step: The number of steps to look ahead and adjust.
gamma: The discount factor.
batch: The SampleBatch to adjust (in place).
Examples:
n-step=3
@ -73,21 +73,19 @@ def compute_advantages(rollout: SampleBatch,
lambda_: float = 1.0,
use_gae: bool = True,
use_critic: bool = True):
"""
Given a rollout, compute its value targets and the advantages.
"""Given a rollout, compute its value targets and the advantages.
Args:
rollout (SampleBatch): SampleBatch of a single trajectory.
last_r (float): Value estimation for last observation.
gamma (float): Discount factor.
lambda_ (float): Parameter for GAE.
use_gae (bool): Using Generalized Advantage Estimation.
use_critic (bool): Whether to use critic (value estimates). Setting
rollout: SampleBatch of a single trajectory.
last_r: Value estimation for last observation.
gamma: Discount factor.
lambda_: Parameter for GAE.
use_gae: Using Generalized Advantage Estimation.
use_critic: Whether to use critic (value estimates). Setting
this to False will use 0 as baseline.
Returns:
SampleBatch (SampleBatch): Object with experience from rollout and
processed rewards.
SampleBatch with experience from rollout and processed rewards.
"""
assert SampleBatch.VF_PREDS in rollout or not use_critic, \
@ -147,17 +145,16 @@ def compute_gae_for_sample_batch(
New columns can be added to sample_batch and existing ones may be altered.
Args:
policy (Policy): The Policy used to generate the trajectory
(`sample_batch`)
sample_batch (SampleBatch): The SampleBatch to postprocess.
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[Episode]): Optional multi-agent episode
object in which the agents operated.
policy: The Policy used to generate the trajectory (`sample_batch`)
sample_batch: The SampleBatch to postprocess.
other_agent_batches: Optional dict of AgentIDs mapping to other
agents' trajectory data (from the same episode).
NOTE: The other agents use the same policy.
episode: Optional multi-agent episode object in which the agents
operated.
Returns:
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
The postprocessed, modified SampleBatch (or a new one).
"""
# Trajectory is actually complete -> last r=0.0.
@ -193,11 +190,11 @@ def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
Args:
gamma (float): The discount factor gamma.
gamma: The discount factor gamma.
Returns:
np.ndarray: The sequence containing the discounted cumulative sums
for each individual reward in `x` till the end of the trajectory.
The sequence containing the discounted cumulative sums
for each individual reward in `x` till the end of the trajectory.
Examples:
>>> x = np.array([0.0, 1.0, 2.0, 3.0])
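
For reference, the recursion described in the `discount_cumsum` docstring above, and the way GAE builds advantages and value targets on top of it, can be written out directly. A pure-NumPy sketch (the helper name, gamma/lambda values and sample arrays are illustrative; this is the textbook formulation, not RLlib's exact implementation):

import numpy as np

def discount_cumsum_ref(x: np.ndarray, gamma: float) -> np.ndarray:
    # y[t] = x[t] + gamma * y[t + 1], computed right-to-left.
    y = np.zeros_like(x)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        y[t] = running
    return y

print(discount_cumsum_ref(np.array([0.0, 1.0, 2.0, 3.0]), 0.9))
# -> approximately [4.707, 5.23, 4.7, 3.0]

# GAE in terms of the same helper:
#   delta[t] = r[t] + gamma * V(s[t+1]) - V(s[t])
#   A[t]     = discount_cumsum(delta, gamma * lambda)[t]
gamma, lambda_ = 0.99, 0.95
rewards = np.array([1.0, 1.0, 1.0])
vf_preds = np.array([0.5, 0.4, 0.3])  # the VF_PREDS column of the rollout
last_r = 0.0                          # trajectory complete -> bootstrap with 0
vpred_next = np.append(vf_preds[1:], last_r)
deltas = rewards + gamma * vpred_next - vf_preds
advantages = discount_cumsum_ref(deltas, gamma * lambda_)
value_targets = advantages + vf_preds  # regression targets for the critic
print(advantages, value_targets)
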

View file

@ -185,8 +185,10 @@ if __name__ == "__main__":
# start with all zeros as state
num_transformers = config["model"][
"attention_num_transformer_units"]
attention_dim = config["model"]["attention_dim"]
memory = config["model"]["attention_memory_inference"]
init_state = state = [
np.zeros([100, 32], np.float32)
np.zeros([memory, attention_dim], np.float32)
for _ in range(num_transformers)
]
# run one iteration until done

View file

@ -20,11 +20,12 @@ import os
import ray
from ray import tune
from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as \
TorchKLCoeffMixin
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin, \
ppo_surrogate_loss as tf_loss
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy, \
KLCoeffMixin as TorchKLCoeffMixin, ppo_surrogate_loss as torch_loss
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.examples.env.two_step_game import TwoStepGame
@ -36,6 +37,7 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \
EntropyCoeffSchedule
from ray.rllib.policy.torch_policy import LearningRateSchedule as TorchLR, \
EntropyCoeffSchedule as TorchEntropyCoeffSchedule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
@ -143,7 +145,8 @@ def centralized_critic_postprocessing(policy,
# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, model, dist_class, train_batch):
CentralizedValueMixin.__init__(policy)
func = tf_loss if not policy.config["framework"] == "torch" else torch_loss
func = tf_loss if not policy.config["framework"] == "torch" \
else PPOTorchPolicy.loss
vf_saved = model.value_function
model.value_function = lambda: policy.model.central_value_function(
@ -194,15 +197,23 @@ CCPPOTFPolicy = PPOTFPolicy.with_updates(
CentralizedValueMixin
])
CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
name="CCPPOTorchPolicy",
postprocess_fn=centralized_critic_postprocessing,
loss_fn=loss_with_central_critic,
before_init=setup_torch_mixins,
mixins=[
TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
CentralizedValueMixin
])
class CCPPOTorchPolicy(PPOTorchPolicy):
def __init__(self, observation_space, action_space, config):
super().__init__(observation_space, action_space, config)
self.compute_central_vf = self.model.central_value_function
@override(PPOTorchPolicy)
def loss(self, model, dist_class, train_batch):
return loss_with_central_critic(self, model, dist_class, train_batch)
@override(PPOTorchPolicy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return centralized_critic_postprocessing(self, sample_batch,
other_agent_batches, episode)
def get_policy_class(config):

View file

@ -326,7 +326,18 @@ def build_eager_tf_policy(
self._re_trace_counter = 0
self._loss_initialized = False
self._loss = loss_fn
# To ensure backward compatibility:
# Old way: If `loss` provided here, use as-is (as a function).
if loss_fn is not None:
self._loss = loss_fn
# New way: Convert the overridden `self.loss` into a plain
# function, so it can be called the same way as `loss` would
# be, ensuring backward compatibility.
elif self.loss.__func__.__qualname__ != "Policy.loss":
self._loss = self.loss.__func__
# `loss` not provided nor overridden from Policy -> Set to None.
else:
self._loss = None
self.batch_divisibility_req = get_batch_divisibility_req(self) if \
callable(get_batch_divisibility_req) else \
@ -828,7 +839,7 @@ def build_eager_tf_policy(
# Calculate the loss(es) inside a tf GradientTape.
with tf.GradientTape(persistent=compute_gradients_fn is not None) \
as tape:
losses = loss_fn(self, self.model, self.dist_class, samples)
losses = self._loss(self, self.model, self.dist_class, samples)
losses = force_list(losses)
# User provided a compute_gradients_fn.

View file

@ -7,10 +7,13 @@ import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional, Type, TYPE_CHECKING
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, \
OverrideToImplementCustomLogic
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@ -88,8 +91,7 @@ class Policy(metaclass=ABCMeta):
"""Initializes a Policy instance.
Args:
observation_space: Observation space of the
policy.
observation_space: Observation space of the policy.
action_space: Action space of the policy.
config: A complete Trainer/Policy config dict. For the default
config keys and values, see rllib/trainer/trainer.py.
@ -404,6 +406,25 @@ class Policy(metaclass=ABCMeta):
# The default implementation just returns the same, unaltered batch.
return sample_batch
@ExperimentalAPI
@OverrideToImplementCustomLogic
def loss(self, model: ModelV2, dist_class: ActionDistribution,
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Loss function for this Policy.
Override this method in order to implement custom loss computations.
Args:
model: The model to calculate the loss(es).
dist_class: The action distribution class to sample actions
from the model's outputs.
train_batch: The input batch on which to calculate the loss.
Returns:
Either a single loss tensor or a list of loss tensors.
"""
raise NotImplementedError
@DeveloperAPI
def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
"""Perform one learning update, given `samples`.
@ -620,8 +641,8 @@ class Policy(metaclass=ABCMeta):
"""Called on an update to global vars.
Args:
global_vars (Dict[str, TensorType]): Global variables by str key,
broadcast from the driver.
global_vars: Global variables by str key, broadcast from the
driver.
"""
# Store the current global time step (sum over all policies' sample
# steps).
@ -632,7 +653,7 @@ class Policy(metaclass=ABCMeta):
"""Export Policy checkpoint to local directory.
Args:
export_dir (str): Local writable directory.
export_dir: Local writable directory.
"""
raise NotImplementedError
@ -646,8 +667,8 @@ class Policy(metaclass=ABCMeta):
implementations for more details.
Args:
export_dir (str): Local writable directory.
onnx (int): If given, will export model in ONNX format. The
export_dir: Local writable directory.
onnx: If given, will export model in ONNX format. The
value of this parameter sets the ONNX OpSet version to use.
"""
raise NotImplementedError
@ -990,6 +1011,10 @@ class Policy(metaclass=ABCMeta):
vr["state_out_{}".format(i)] = ViewRequirement(
space=space, used_for_training=True)
@DeveloperAPI
def __repr__(self):
return type(self).__name__
@Deprecated(new="get_exploration_state", error=False)
def get_exploration_info(self) -> Dict[str, TensorType]:
return self.get_exploration_state()
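
With the new `Policy.loss()` hook above, a torch policy can define its loss by overriding the method rather than passing a `loss` function to the constructor. A hedged sketch of what such an override might look like (the class and its REINFORCE-style objective are illustrative only, not part of this PR):

import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy


class MyTorchPolicy(TorchPolicy):
    def loss(self, model, dist_class, train_batch):
        # Forward pass to get the action distribution under the current model.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        # Log-prob of the taken actions, weighted by the observed rewards.
        return -torch.mean(
            action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
            train_batch[SampleBatch.REWARDS])
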

View file

@ -25,7 +25,7 @@ jax, _ = try_import_jax()
torch, _ = try_import_torch()
# TODO: (sven) Unify this with `build_tf_policy` as well.
# TODO: Deprecate in favor of directly sub-classing from TorchPolicy.
@DeveloperAPI
def build_policy_class(
name: str,

View file

@ -8,10 +8,11 @@ import os
import threading
import time
import tree # pip install dm_tree
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, \
TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, \
Union, TYPE_CHECKING
import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
@ -19,7 +20,7 @@ from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import force_list, NullContextManager
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
@ -27,8 +28,8 @@ from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TensorStructType, TrainerConfigDict
from ray.rllib.utils.typing import GradInfoDict, ModelGradients, \
ModelWeights, TensorType, TensorStructType, TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation import Episode # noqa
@ -49,11 +50,12 @@ class TorchPolicy(Policy):
action_space: gym.spaces.Space,
config: TrainerConfigDict,
*,
model: ModelV2,
loss: Callable[[
model: Optional[TorchModelV2] = None,
loss: Optional[Callable[[
Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
], Union[TensorType, List[TensorType]]],
action_distribution_class: Type[TorchDistributionWrapper],
], Union[TensorType, List[TensorType]]]] = None,
action_distribution_class: Optional[Type[
TorchDistributionWrapper]] = None,
action_sampler_fn: Optional[Callable[[
TensorType, List[TensorType]
], Tuple[TensorType, TensorType]]] = None,
@ -113,7 +115,7 @@ class TorchPolicy(Policy):
get_batch_divisibility_req: Optional callable that returns the
divisibility requirement for sample batches given the Policy.
"""
self.framework = "torch"
self.framework = config["framework"] = "torch"
super().__init__(observation_space, action_space, config)
# Create multi-GPU model towers, if necessary.
@ -128,6 +130,19 @@ class TorchPolicy(Policy):
# - In case of just one device (1 (fake or real) GPU or 1 CPU), no
# parallelization will be done.
# If no Model is provided, build a default one here.
if model is None:
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], framework=self.framework)
model = ModelCatalog.get_model_v2(
obs_space=self.observation_space,
action_space=self.action_space,
num_outputs=logit_dim,
model_config=self.config["model"],
framework=self.framework)
if action_distribution_class is None:
action_distribution_class = dist_class
# Get devices to build the graph on.
worker_idx = self.config.get("worker_index", 0)
if not config["_fake_gpus"] and \
@ -213,7 +228,18 @@ class TorchPolicy(Policy):
self.exploration = self._create_exploration()
self.unwrapped_model = model # used to support DistributedDataParallel
self._loss = loss
# To ensure backward compatibility:
# Old way: If `loss` provided here, use as-is (as a function).
if loss is not None:
self._loss = loss
# New way: Convert the overridden `self.loss` into a plain function,
# so it can be called the same way as `loss` would be, ensuring
# backward compatibility.
elif self.loss.__func__.__qualname__ != "Policy.loss":
self._loss = self.loss.__func__
# `loss` not provided nor overridden from Policy -> Set to None.
else:
self._loss = None
self._optimizers = force_list(self.optimizer())
# Store, which params (by index within the model's list of
# parameters) should be updated per optimizer.
@ -616,11 +642,11 @@ class TorchPolicy(Policy):
Returns:
The list of stats tensor (structs) of all towers, copied to this
Policy's device.
Policy's device.
Raises:
AssertionError: If the `stats_name` cannot be found in any one
of the tower's `tower_stats` dicts.
of the tower's `tower_stats` dicts.
"""
data = []
for tower in self.model_gpu_towers:
@ -697,7 +723,7 @@ class TorchPolicy(Policy):
@DeveloperAPI
def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
loss: TensorType):
loss: TensorType) -> Dict[str, TensorType]:
"""Called after each optimizer.zero_grad() + loss.backward() call.
Called for each self._optimizers/loss-value pair.
@ -705,22 +731,21 @@ class TorchPolicy(Policy):
E.g. for gradient clipping.
Args:
optimizer (torch.optim.Optimizer): A torch optimizer object.
loss (TensorType): The loss tensor associated with the optimizer.
optimizer: A torch optimizer object.
loss: The loss tensor associated with the optimizer.
Returns:
Dict[str, TensorType]: A dict with information on the gradient
processing step.
A dict with information on the gradient processing step.
"""
return {}
@DeveloperAPI
def extra_compute_grad_fetches(self) -> Dict[str, any]:
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
"""Extra values to fetch and return from compute_gradients().
Returns:
Dict[str, any]: Extra fetch dict to be added to the fetch dict
of the compute_gradients call.
Extra fetch dict to be added to the fetch dict of the
`compute_gradients` call.
"""
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
@ -732,15 +757,15 @@ class TorchPolicy(Policy):
"""Returns dict of extra info to include in experience batch.
Args:
input_dict (Dict[str, TensorType]): Dict of model input tensors.
state_batches (List[TensorType]): List of state tensors.
model (TorchModelV2): Reference to the model object.
action_dist (TorchDistributionWrapper): Torch action dist object
input_dict: Dict of model input tensors.
state_batches: List of state tensors.
model: Reference to the model object.
action_dist: Torch action dist object
to get log-probs (e.g. for already sampled actions).
Returns:
Dict[str, TensorType]: Extra outputs to return in a
compute_actions() call (3rd return value).
Extra outputs to return in a `compute_actions_from_input_dict()`
call (3rd return value).
"""
return {}
@ -750,12 +775,11 @@ class TorchPolicy(Policy):
"""Return dict of extra grad info.
Args:
train_batch (SampleBatch): The training batch for which to produce
train_batch: The training batch for which to produce
extra grad info for.
Returns:
Dict[str, TensorType]: The info dict carrying grad info per str
key.
The info dict carrying grad info per str key.
"""
return {}
@ -766,8 +790,7 @@ class TorchPolicy(Policy):
"""Custom the local PyTorch optimizer(s) to use.
Returns:
Union[List[torch.optim.Optimizer], torch.optim.Optimizer]:
The local PyTorch optimizer(s) to use for this Policy.
The local PyTorch optimizer(s) to use for this Policy.
"""
if hasattr(self, "config"):
optimizers = [
@ -789,7 +812,9 @@ class TorchPolicy(Policy):
Creates a TorchScript model and saves it.
Args:
export_dir (str): Local writable directory or filename.
export_dir: Local writable directory or filename.
onnx: If given, will export model in ONNX format. The
value of this parameter sets the ONNX OpSet version to use.
"""
self._lazy_tensor_dict(self._dummy_batch)
# Provide dummy state inputs if not an RNN (torch cannot jit with
@ -852,8 +877,7 @@ class TorchPolicy(Policy):
"""Shared forward pass logic (w/ and w/o trajectory view API).
Returns:
Tuple:
- actions, state_out, extra_fetches, logp.
A tuple consisting of a) actions, b) state_out, c) extra_fetches.
"""
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
@ -956,7 +980,9 @@ class TorchPolicy(Policy):
convert_to_torch_tensor, device=device or self.device))
return postprocessed_batch
def _multi_gpu_parallel_grad_calc(self, sample_batches):
def _multi_gpu_parallel_grad_calc(
self, sample_batches: List[SampleBatch]
) -> List[Tuple[List[TensorType], GradInfoDict]]:
"""Performs a parallelized loss and gradient calculation over the batch.
Splits up the given train batch into n shards (n=number of this
@ -965,12 +991,12 @@ class TorchPolicy(Policy):
(self.model_gpu_towers). Then returns each tower's outputs.
Args:
sample_batches (List[SampleBatch]): A list of SampleBatch shards to
sample_batches: A list of SampleBatch shards to
calculate loss and gradients for.
Returns:
List[Tuple[List[TensorType], StatsDict]]: A list (one item per
device) of 2-tuples with 1) gradient list and 2) stats dict.
A list (one item per device) of 2-tuples, each with 1) gradient
list and 2) grad info dict.
"""
assert len(self.model_gpu_towers) == len(sample_batches)
lock = threading.Lock()
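
The constructor change above accepts either the old-style `loss` argument or a `loss()` method overridden on the policy class itself, telling the two apart by comparing the bound method's `__qualname__` against the `Policy.loss` placeholder. The same trick in isolation (hypothetical `Base`/`WithLoss` classes, not RLlib code):

class Base:
    def loss(self, batch):
        raise NotImplementedError

class WithLoss(Base):
    def loss(self, batch):
        return sum(batch)

def overrides_loss(obj: Base) -> bool:
    # A subclass's method carries its own __qualname__, so comparing against
    # the base class's placeholder reveals whether `loss` was overridden.
    return obj.loss.__func__.__qualname__ != "Base.loss"

print(overrides_loss(Base()))      # -> False: fall back to the `loss` argument
print(overrides_loss(WithLoss()))  # -> True: use the overridden method
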

View file

@ -65,5 +65,35 @@ def ExperimentalAPI(obj):
return obj
def OverrideToImplementCustomLogic(obj):
"""Users should override this in their sub-classes to implement custom logic.
Used in Trainer and Policy to tag methods that need overriding, e.g.
`Policy.loss()`.
"""
return obj
def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
"""Users should override this in their sub-classes to implement custom logic.
When doing so, it is recommended (but not required) to call the super-class'
corresponding method.
Used in Trainer and Policy to tag methods that need overriding, but the
super class' method should still be called, e.g.
`Trainer.setup()`.
Examples:
>>> @override(Trainable)
... @OverrideToImplementCustomLogic_CallToSuperRecommended
... def setup(self, config):
... # implement custom setup logic here ...
... super().setup(config)
... # ... or here (after having called super()'s setup method).
"""
return obj
# Backward compatibility.
Deprecated = Deprecated
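
Since both tags above simply return the decorated object, they carry no runtime behavior and can be applied to any base-class method purely as documentation. A minimal sketch of tagging an overridable hook (the `BaseAlgo` class is hypothetical):

from ray.rllib.utils.annotations import OverrideToImplementCustomLogic


class BaseAlgo:
    @OverrideToImplementCustomLogic
    def loss(self, train_batch):
        # Sub-classes override this with their own loss computation.
        raise NotImplementedError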