Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00

Revert "Revert [RLlib] POC: Deprecate build_policy (policy template) for torch only; PPOTorchPolicy (#20061) (#20399)" (#20417)

This reverts commit 90dc5460d4.

parent 6504ad6bb2
commit f82880eda1

18 changed files with 485 additions and 345 deletions
@@ -117,6 +117,7 @@ class Trainable:
self._stderr_file = stderr_file

start_time = time.time()
self._local_ip = self.get_current_ip()
self.setup(copy.deepcopy(self.config))
setup_time = time.time() - start_time
if setup_time > SETUP_TIME_THRESHOLD:

@@ -124,7 +125,6 @@ class Trainable:
"trainable is slow to initialize, consider setting "
"reuse_actors=True to reduce actor creation "
"overheads.".format(setup_time))
self._local_ip = self.get_current_ip()
log_sys_usage = self.config.get("log_sys_usage", False)
self._monitor = UtilMonitor(start=log_sys_usage)
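The warning added above points users at Tune's actor reuse. As a hedged illustration of acting on it (the trainable below is a made-up placeholder, not part of this commit), enabling reuse looks roughly like this:

```python
from ray import tune

# Hypothetical trainable used only for illustration.
def my_trainable(config):
    tune.report(score=config["x"] ** 2)

# reuse_actors=True keeps Trainable actors alive across trials, avoiding
# the per-trial setup cost that the timing code above measures.
tune.run(
    my_trainable,
    config={"x": tune.grid_search([1, 2, 3])},
    reuse_actors=True,
)
```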
@@ -1654,7 +1654,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=60"]
args = ["--as-test", "--stop-reward=70"]
)

py_test(

@@ -1663,7 +1663,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=60", "--framework torch"]
args = ["--as-test", "--stop-reward=70", "--framework torch"]
)

py_test(
@@ -1,13 +1,13 @@
import gym
from typing import Optional, Dict
from typing import Dict, List, Optional

import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch

@@ -97,11 +97,33 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
}

def model_value_predictions(
policy: Policy, input_dict: Dict[str, TensorType], state_batches,
model: ModelV2,
action_dist: ActionDistribution) -> Dict[str, TensorType]:
return {SampleBatch.VF_PREDS: model.value_function()}
def vf_preds_fetches(
policy: Policy, input_dict: Dict[str, TensorType],
state_batches: List[TensorType], model: ModelV2,
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.

Args:
policy (Policy): The Policy to perform the extra action fetch on.
input_dict (Dict[str, TensorType]): The input dict used for the action
computing forward pass.
state_batches (List[TensorType]): List of state tensors (empty for
non-RNNs).
model (ModelV2): The Model object of the Policy.
action_dist (TorchDistributionWrapper): The instantiated distribution
object, resulting from the model's outputs and the given
distribution class.

Returns:
Dict[str, TensorType]: Dict with extra tf fetches to perform per
action computation.
"""
# Return value function outputs. VF estimates will hence be added to the
# SampleBatches produced by the sampler(s) to generate the train batches
# going into the loss function.
return {
SampleBatch.VF_PREDS: model.value_function(),
}

def torch_optimizer(policy: Policy,

@@ -109,6 +131,41 @@ def torch_optimizer(policy: Policy,
return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])

class ValueNetworkMixin:
"""Assigns the `_value()` method to the PPOPolicy.

This way, Policy can call `_value()` to get the current VF estimate on a
single(!) observation (as done in `postprocess_trajectory_fn`).
Note: When doing this, an actual forward pass is being performed.
This is different from only calling `model.value_function()`, where
the result of the most recent forward pass is being used to return an
already calculated tensor.
"""

def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.

def value(**input_dict):
input_dict = SampleBatch(input_dict)
input_dict = self._lazy_tensor_dict(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0].item()

# When not doing GAE, we do not require the value function's output.
else:

def value(*args, **kwargs):
return 0.0

self._value = value

def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:

@@ -133,7 +190,7 @@ A3CTorchPolicy = build_policy_class(
loss_fn=actor_critic_loss,
stats_fn=stats,
postprocess_fn=compute_gae_for_sample_batch,
extra_action_out_fn=model_value_predictions,
extra_action_out_fn=vf_preds_fetches,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=torch_optimizer,
before_loss_init=setup_mixins,
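The `vf_preds_fetches()` docstring above describes the extra-fetches mechanism. A self-contained toy (no RLlib imports; all names below are illustrative) of what it achieves: whatever the hook returns is stored per timestep next to the action, so value predictions are already in the rollout batch when GAE postprocessing runs.

```python
# Toy illustration of the "extra action fetches" pattern shown above.
class ToyPolicy:
    def __init__(self):
        self._last_value = 0.0

    def compute_action(self, obs):
        action = obs * 2              # stand-in for the model forward pass
        self._last_value = obs * 0.5  # stand-in for model.value_function()
        return action

    def extra_action_out(self):
        # Mirrors vf_preds_fetches(): expose the VF estimate of the
        # forward pass that produced the action.
        return {"vf_preds": self._last_value}

batch = {"actions": [], "vf_preds": []}
policy = ToyPolicy()
for obs in [0.1, 0.2, 0.3]:
    batch["actions"].append(policy.compute_action(obs))
    batch["vf_preds"].append(policy.extra_action_out()["vf_preds"])
print(batch)  # vf_preds travel with the rollout, ready for GAE.
```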
@@ -5,9 +5,9 @@ from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin, \
vf_preds_fetches
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
ValueNetworkMixin
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.framework import try_import_torch
@@ -2,7 +2,7 @@ import gym
from typing import Dict

import ray
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.policy_template import build_policy_class
@@ -4,10 +4,10 @@ import logging
from typing import Tuple, Type

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
@@ -16,8 +16,7 @@ from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
choose_optimizer
from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model, \
postprocess_trajectory
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \
KLCoeffMixin
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import \

@@ -281,6 +280,30 @@ def add_values(policy, input_dict, state_batches, model, action_dist):
return out

class KLCoeffMixin:
"""Assigns the `update_kl()` method to the PPOPolicy.

This is used in PPO's execution plan (see ppo.py) for updating the KL
coefficient after each learning step based on `config.kl_target` and
the measured KL value (from the train_batch).
"""

def __init__(self, config):
# The current KL value (as python float).
self.kl_coeff = config["kl_coeff"]
# Constant target value.
self.kl_target = config["kl_target"]

def update_kl(self, sampled_kl):
# Update the current KL value based on the recently measured value.
if sampled_kl > 2.0 * self.kl_target:
self.kl_coeff *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff *= 0.5
# Return the current KL value.
return self.kl_coeff

def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict):
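A standalone numeric sketch of the adaptive-KL rule that `KLCoeffMixin.update_kl()` above implements (coefficient and target values below are illustrative, not taken from any config):

```python
# Mirrors KLCoeffMixin.update_kl(): raise the penalty when measured KL
# overshoots the target, relax it when it undershoots.
class KLController:
    def __init__(self, kl_coeff=0.2, kl_target=0.01):
        self.kl_coeff, self.kl_target = kl_coeff, kl_target

    def update_kl(self, sampled_kl):
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff *= 0.5
        return self.kl_coeff

ctrl = KLController()
print([round(ctrl.update_kl(kl), 4) for kl in [0.05, 0.03, 0.004, 0.002]])
# -> [0.3, 0.45, 0.225, 0.1125]
```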
@@ -1,7 +1,3 @@
"""
PyTorch policy class used for PPO.
"""
import gym
import logging
from typing import Dict, List, Type, Union

@@ -10,185 +6,167 @@ from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
LearningRateSchedule
LearningRateSchedule, TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping, \
explained_variance, sequence_mask
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)

def ppo_surrogate_loss(
policy: Policy, model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.
class PPOTorchPolicy(TorchPolicy, LearningRateSchedule, EntropyCoeffSchedule):
"""PyTorch policy class used with PPOTrainer."""

Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]: The action distr. class.
train_batch (SampleBatch): The training data.
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
setup_config(self, observation_space, action_space, config)

Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
logits, state = model(train_batch)
curr_action_dist = dist_class(logits, model)
TorchPolicy.__init__(
self,
observation_space,
action_space,
config,
max_seq_len=config["model"]["max_seq_len"])

# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
B = len(train_batch[SampleBatch.SEQ_LENS])
max_seq_len = logits.shape[0] // B
mask = sequence_mask(
train_batch[SampleBatch.SEQ_LENS],
max_seq_len,
time_major=model.is_time_major())
mask = torch.reshape(mask, [-1])
num_valid = torch.sum(mask)
EntropyCoeffSchedule.__init__(self, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(self, config["lr"],
config["lr_schedule"])

def reduce_mean_valid(t):
return torch.sum(t[mask]) / num_valid

# non-RNN case: No masking.
else:
mask = None
reduce_mean_valid = torch.mean

prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
model)

logp_ratio = torch.exp(
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
train_batch[SampleBatch.ACTION_LOGP])
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl_loss = reduce_mean_valid(action_kl)

curr_entropy = curr_action_dist.entropy()
mean_entropy = reduce_mean_valid(curr_entropy)

surrogate_loss = torch.min(
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
logp_ratio, 1 - policy.config["clip_param"],
1 + policy.config["clip_param"]))
mean_policy_loss = reduce_mean_valid(-surrogate_loss)

# Compute a value function loss.
if policy.config["use_critic"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
vf_loss1 = torch.pow(
value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped = prev_value_fn_out + torch.clamp(
value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
policy.config["vf_clip_param"])
vf_loss2 = torch.pow(
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss = torch.max(vf_loss1, vf_loss2)
mean_vf_loss = reduce_mean_valid(vf_loss)
# Ignore the value function.
else:
vf_loss = mean_vf_loss = 0.0

total_loss = reduce_mean_valid(-surrogate_loss +
policy.kl_coeff * action_kl +
policy.config["vf_loss_coeff"] * vf_loss -
policy.entropy_coeff * curr_entropy)

# Store values for stats function in model (tower), such that for
# multi-GPU, we do not override them during the parallel loss phase.
model.tower_stats["total_loss"] = total_loss
model.tower_stats["mean_policy_loss"] = mean_policy_loss
model.tower_stats["mean_vf_loss"] = mean_vf_loss
model.tower_stats["vf_explained_var"] = explained_variance(
train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
model.tower_stats["mean_entropy"] = mean_entropy
model.tower_stats["mean_kl_loss"] = mean_kl_loss

return total_loss

def kl_and_loss_stats(policy: Policy,
train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Stats function for PPO. Returns a dict with important KL and loss stats.

Args:
policy (Policy): The Policy to generate stats for.
train_batch (SampleBatch): The SampleBatch (already) used for training.

Returns:
Dict[str, TensorType]: The stats dict.
"""
return {
"cur_kl_coeff": policy.kl_coeff,
"cur_lr": policy.cur_lr,
"total_loss": torch.mean(
torch.stack(policy.get_tower_stats("total_loss"))),
"policy_loss": torch.mean(
torch.stack(policy.get_tower_stats("mean_policy_loss"))),
"vf_loss": torch.mean(
torch.stack(policy.get_tower_stats("mean_vf_loss"))),
"vf_explained_var": torch.mean(
torch.stack(policy.get_tower_stats("vf_explained_var"))),
"kl": torch.mean(torch.stack(policy.get_tower_stats("mean_kl_loss"))),
"entropy": torch.mean(
torch.stack(policy.get_tower_stats("mean_entropy"))),
"entropy_coeff": policy.entropy_coeff,
}

def vf_preds_fetches(
policy: Policy, input_dict: Dict[str, TensorType],
state_batches: List[TensorType], model: ModelV2,
action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.

Args:
policy (Policy): The Policy to perform the extra action fetch on.
input_dict (Dict[str, TensorType]): The input dict used for the action
computing forward pass.
state_batches (List[TensorType]): List of state tensors (empty for
non-RNNs).
model (ModelV2): The Model object of the Policy.
action_dist (TorchDistributionWrapper): The instantiated distribution
object, resulting from the model's outputs and the given
distribution class.

Returns:
Dict[str, TensorType]: Dict with extra tf fetches to perform per
action computation.
"""
# Return value function outputs. VF estimates will hence be added to the
# SampleBatches produced by the sampler(s) to generate the train batches
# going into the loss function.
return {
SampleBatch.VF_PREDS: model.value_function(),
}

class KLCoeffMixin:
"""Assigns the `update_kl()` method to the PPOPolicy.

This is used in PPO's execution plan (see ppo.py) for updating the KL
coefficient after each learning step based on `config.kl_target` and
the measured KL value (from the train_batch).
"""

def __init__(self, config):
# The current KL value (as python float).
self.kl_coeff = config["kl_coeff"]
self.kl_coeff = self.config["kl_coeff"]
# Constant target value.
self.kl_target = config["kl_target"]
self.kl_target = self.config["kl_target"]

# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()

@override(TorchPolicy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
# Do all post-processing always with no_grad().
# Not using this here will introduce a memory leak
# in torch (issue #6962).
# TODO: no_grad still necessary?
with torch.no_grad():
return compute_gae_for_sample_batch(self, sample_batch,
other_agent_batches, episode)

# TODO: Add method to Policy base class (as the new way of defining loss
# functions (instead of passing 'loss` to the super's constructor)).
@override(TorchPolicy)
def loss(self, model: ModelV2, dist_class: Type[ActionDistribution],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.

Args:
model: The Model to calculate the loss for.
dist_class: The action distr. class.
train_batch: The training data.

Returns:
The PPO loss tensor given the input batch.
"""

logits, state = model(train_batch)
curr_action_dist = dist_class(logits, model)

# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
B = len(train_batch[SampleBatch.SEQ_LENS])
max_seq_len = logits.shape[0] // B
mask = sequence_mask(
train_batch[SampleBatch.SEQ_LENS],
max_seq_len,
time_major=model.is_time_major())
mask = torch.reshape(mask, [-1])
num_valid = torch.sum(mask)

def reduce_mean_valid(t):
return torch.sum(t[mask]) / num_valid

# non-RNN case: No masking.
else:
mask = None
reduce_mean_valid = torch.mean

prev_action_dist = dist_class(
train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

logp_ratio = torch.exp(
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
train_batch[SampleBatch.ACTION_LOGP])
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl_loss = reduce_mean_valid(action_kl)

curr_entropy = curr_action_dist.entropy()
mean_entropy = reduce_mean_valid(curr_entropy)

surrogate_loss = torch.min(
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
logp_ratio, 1 - self.config["clip_param"],
1 + self.config["clip_param"]))
mean_policy_loss = reduce_mean_valid(-surrogate_loss)

# Compute a value function loss.
if self.config["use_critic"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
vf_loss1 = torch.pow(
value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_clipped = prev_value_fn_out + torch.clamp(
value_fn_out - prev_value_fn_out,
-self.config["vf_clip_param"], self.config["vf_clip_param"])
vf_loss2 = torch.pow(
vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss = torch.max(vf_loss1, vf_loss2)
mean_vf_loss = reduce_mean_valid(vf_loss)
# Ignore the value function.
else:
vf_loss = mean_vf_loss = 0.0

total_loss = reduce_mean_valid(-surrogate_loss +
self.kl_coeff * action_kl +
self.config["vf_loss_coeff"] * vf_loss -
self.entropy_coeff * curr_entropy)

# Store values for stats function in model (tower), such that for
# multi-GPU, we do not override them during the parallel loss phase.
model.tower_stats["total_loss"] = total_loss
model.tower_stats["mean_policy_loss"] = mean_policy_loss
model.tower_stats["mean_vf_loss"] = mean_vf_loss
model.tower_stats["vf_explained_var"] = explained_variance(
train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
model.tower_stats["mean_entropy"] = mean_entropy
model.tower_stats["mean_kl_loss"] = mean_kl_loss

return total_loss

def _value(self, **input_dict):
# When doing GAE, we need the value function estimate on the
# observation.
if self.config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
input_dict = self._lazy_tensor_dict(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0].item()
# When not doing GAE, we do not require the value function's output.
else:
return 0.0

def update_kl(self, sampled_kl):
# Update the current KL value based on the recently measured value.

@@ -199,75 +177,56 @@ class KLCoeffMixin:
# Return the current KL value.
return self.kl_coeff

# TODO: Make this an event-style subscription (e.g.:
# "after_actions_computed").
@override(TorchPolicy)
def extra_action_out(self, input_dict, state_batches, model, action_dist):
# Return value function outputs. VF estimates will hence be added to
# the SampleBatches produced by the sampler(s) to generate the train
# batches going into the loss function.
return {
SampleBatch.VF_PREDS: model.value_function(),
}

class ValueNetworkMixin:
"""Assigns the `_value()` method to the PPOPolicy.
# TODO: Make this an event-style subscription (e.g.:
# "after_gradients_computed").
@override(TorchPolicy)
def extra_grad_process(self, local_optimizer, loss):
return apply_grad_clipping(self, local_optimizer, loss)

This way, Policy can call `_value()` to get the current VF estimate on a
single(!) observation (as done in `postprocess_trajectory_fn`).
Note: When doing this, an actual forward pass is being performed.
This is different from only calling `model.value_function()`, where
the result of the most recent forward pass is being used to return an
already calculated tensor.
"""
# TODO: Make this an event-style subscription (e.g.:
# "after_losses_computed").
@override(TorchPolicy)
def extra_grad_info(self,
train_batch: SampleBatch) -> Dict[str, TensorType]:
return convert_to_numpy({
"cur_kl_coeff": self.kl_coeff,
"cur_lr": self.cur_lr,
"total_loss": torch.mean(
torch.stack(self.get_tower_stats("total_loss"))),
"policy_loss": torch.mean(
torch.stack(self.get_tower_stats("mean_policy_loss"))),
"vf_loss": torch.mean(
torch.stack(self.get_tower_stats("mean_vf_loss"))),
"vf_explained_var": torch.mean(
torch.stack(self.get_tower_stats("vf_explained_var"))),
"kl": torch.mean(
torch.stack(self.get_tower_stats("mean_kl_loss"))),
"entropy": torch.mean(
torch.stack(self.get_tower_stats("mean_entropy"))),
"entropy_coeff": self.entropy_coeff,
})

def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.

def value(**input_dict):
input_dict = SampleBatch(input_dict)
input_dict = self._lazy_tensor_dict(input_dict)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0].item()

# When not doing GAE, we do not require the value function's output.
else:

def value(*args, **kwargs):
return 0.0

self._value = value

def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Call all mixin classes' constructors before PPOPolicy initialization.

Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
KLCoeffMixin.__init__(policy, config)
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])

# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
PPOTorchPolicy = build_policy_class(
name="PPOTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
loss_fn=ppo_surrogate_loss,
stats_fn=kl_and_loss_stats,
extra_action_out_fn=vf_preds_fetches,
postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
],
)
# TODO: Make lr-schedule and entropy-schedule Plugin-style functionalities
# that can be added (via the config) to any Trainer/Policy.
@override(TorchPolicy)
def on_global_var_update(self, global_vars):
super().on_global_var_update(global_vars)
if self._lr_schedule:
self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
for opt in self._optimizers:
for p in opt.param_groups:
p["lr"] = self.cur_lr
if self._entropy_coeff_schedule is not None:
self.entropy_coeff = self._entropy_coeff_schedule.value(
global_vars["timestep"])
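With `PPOTorchPolicy` now a plain class (see the hunks above), custom PPO variants can subclass it and override `loss()` directly instead of going through `build_policy_class()`. A hedged sketch against the Ray version this diff targets; the L2 penalty term is made up purely for illustration:

```python
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.utils.annotations import override


class L2PPOTorchPolicy(PPOTorchPolicy):
    """PPO whose loss adds a small, illustrative L2 weight penalty."""

    @override(PPOTorchPolicy)
    def loss(self, model, dist_class, train_batch):
        # Reuse PPO's surrogate/VF/entropy loss as-is ...
        ppo_loss = super().loss(model, dist_class, train_batch)
        # ... and add a tiny L2 regularizer over the model's parameters.
        l2 = sum(p.pow(2.0).sum() for p in model.parameters())
        return ppo_loss + 1e-4 * l2
```

A Trainer would pick such a class up via its `get_policy_class` hook, exactly as the centralized-critic example further down in this diff subclasses `PPOTorchPolicy` directly.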
@@ -7,8 +7,7 @@ from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
ppo_surrogate_loss_tf
from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
ppo_surrogate_loss_torch
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical

@@ -272,8 +271,8 @@ class TestPPO(unittest.TestCase):
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
train_batch)
elif fw == "torch":
ppo_surrogate_loss_torch(policy, policy.model,
TorchCategorical, train_batch)
PPOTorchPolicy.loss(policy, policy.model, policy.dist_class,
train_batch)

vars = policy.model.variables() if fw != "torch" else \
list(policy.model.parameters())
@@ -123,7 +123,7 @@ class AlphaZeroPolicy(TorchPolicy):
sample_batch["t"])
return sample_batch

@override(Policy)
@override(TorchPolicy)
def learn_on_batch(self, postprocessed_batch):
train_batch = self._lazy_tensor_dict(postprocessed_batch)
@@ -23,9 +23,9 @@ def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
n is truncated to fit in the traj length.

Args:
n_step (int): The number of steps to look ahead and adjust.
gamma (float): The discount factor.
batch (SampleBatch): The SampleBatch to adjust (in place).
n_step: The number of steps to look ahead and adjust.
gamma: The discount factor.
batch: The SampleBatch to adjust (in place).

Examples:
n-step=3
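For the n-step adjustment documented above, a standalone numeric sketch (not the RLlib implementation) of what the rewrite does with `n_step=3`, `gamma=0.9`, and a constant reward of 1.0:

```python
gamma, n_step = 0.9, 3
rewards = [1.0, 1.0, 1.0, 1.0, 1.0]

n_step_rewards = []
for t in range(len(rewards)):
    r, discount = 0.0, 1.0
    # Sum the next n rewards, truncating n at the end of the trajectory.
    for k in range(t, min(t + n_step, len(rewards))):
        r += discount * rewards[k]
        discount *= gamma
    n_step_rewards.append(r)

print([round(r, 2) for r in n_step_rewards])  # [2.71, 2.71, 2.71, 1.9, 1.0]
```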
|
@ -73,21 +73,19 @@ def compute_advantages(rollout: SampleBatch,
|
|||
lambda_: float = 1.0,
|
||||
use_gae: bool = True,
|
||||
use_critic: bool = True):
|
||||
"""
|
||||
Given a rollout, compute its value targets and the advantages.
|
||||
"""Given a rollout, compute its value targets and the advantages.
|
||||
|
||||
Args:
|
||||
rollout (SampleBatch): SampleBatch of a single trajectory.
|
||||
last_r (float): Value estimation for last observation.
|
||||
gamma (float): Discount factor.
|
||||
lambda_ (float): Parameter for GAE.
|
||||
use_gae (bool): Using Generalized Advantage Estimation.
|
||||
use_critic (bool): Whether to use critic (value estimates). Setting
|
||||
rollout: SampleBatch of a single trajectory.
|
||||
last_r: Value estimation for last observation.
|
||||
gamma: Discount factor.
|
||||
lambda_: Parameter for GAE.
|
||||
use_gae: Using Generalized Advantage Estimation.
|
||||
use_critic: Whether to use critic (value estimates). Setting
|
||||
this to False will use 0 as baseline.
|
||||
|
||||
Returns:
|
||||
SampleBatch (SampleBatch): Object with experience from rollout and
|
||||
processed rewards.
|
||||
SampleBatch with experience from rollout and processed rewards.
|
||||
"""
|
||||
|
||||
assert SampleBatch.VF_PREDS in rollout or not use_critic, \
|
||||
|
@ -147,17 +145,16 @@ def compute_gae_for_sample_batch(
|
|||
New columns can be added to sample_batch and existing ones may be altered.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy used to generate the trajectory
|
||||
(`sample_batch`)
|
||||
sample_batch (SampleBatch): The SampleBatch to postprocess.
|
||||
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
|
||||
dict of AgentIDs mapping to other agents' trajectory data (from the
|
||||
same episode). NOTE: The other agents use the same policy.
|
||||
episode (Optional[Episode]): Optional multi-agent episode
|
||||
object in which the agents operated.
|
||||
policy: The Policy used to generate the trajectory (`sample_batch`)
|
||||
sample_batch: The SampleBatch to postprocess.
|
||||
other_agent_batches: Optional dict of AgentIDs mapping to other
|
||||
agents' trajectory data (from the same episode).
|
||||
NOTE: The other agents use the same policy.
|
||||
episode: Optional multi-agent episode object in which the agents
|
||||
operated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
|
||||
The postprocessed, modified SampleBatch (or a new one).
|
||||
"""
|
||||
|
||||
# Trajectory is actually complete -> last r=0.0.
|
||||
|
@ -193,11 +190,11 @@ def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
|
|||
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
|
||||
|
||||
Args:
|
||||
gamma (float): The discount factor gamma.
|
||||
gamma: The discount factor gamma.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The sequence containing the discounted cumulative sums
|
||||
for each individual reward in `x` till the end of the trajectory.
|
||||
The sequence containing the discounted cumulative sums
|
||||
for each individual reward in `x` till the end of the trajectory.
|
||||
|
||||
Examples:
|
||||
>>> x = np.array([0.0, 1.0, 2.0, 3.0])
|
||||
|
|
|
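The docstring example above cuts off after defining `x`; here is a worked version using a standalone reimplementation (for illustration only, not the RLlib function itself):

```python
import numpy as np

def discount_cumsum(x, gamma):
    # Reversed-scan form of: y[t] = x[t] + gamma * y[t + 1]
    out = np.zeros_like(x)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out

x = np.array([0.0, 1.0, 2.0, 3.0])
print(discount_cumsum(x, gamma=0.5))
# -> [1.375, 2.75, 3.5, 3.0]; e.g. 0.0 + 0.5*1.0 + 0.25*2.0 + 0.125*3.0 = 1.375
```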
@@ -185,8 +185,10 @@ if __name__ == "__main__":
# start with all zeros as state
num_transformers = config["model"][
"attention_num_transformer_units"]
attention_dim = config["model"]["attention_dim"]
memory = config["model"]["attention_memory_inference"]
init_state = state = [
np.zeros([100, 32], np.float32)
np.zeros([memory, attention_dim], np.float32)
for _ in range(num_transformers)
]
# run one iteration until done
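The example change above derives the initial attention state from the model config instead of hard-coding `[100, 32]`. In isolation it amounts to the following (the config values below are assumed defaults, used only to make the sketch runnable):

```python
import numpy as np

# Assumed model-config values for illustration.
model_config = {
    "attention_num_transformer_units": 1,
    "attention_dim": 32,
    "attention_memory_inference": 100,
}

# One zero-filled [memory, attention_dim] block per transformer unit.
init_state = [
    np.zeros(
        [model_config["attention_memory_inference"],
         model_config["attention_dim"]],
        np.float32,
    )
    for _ in range(model_config["attention_num_transformer_units"])
]
print([s.shape for s in init_state])  # [(100, 32)]
```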
@@ -20,11 +20,12 @@ import os

import ray
from ray import tune
from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as \
TorchKLCoeffMixin
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin, \
ppo_surrogate_loss as tf_loss
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy, \
KLCoeffMixin as TorchKLCoeffMixin, ppo_surrogate_loss as torch_loss
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.examples.env.two_step_game import TwoStepGame

@@ -36,6 +37,7 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \
EntropyCoeffSchedule
from ray.rllib.policy.torch_policy import LearningRateSchedule as TorchLR, \
EntropyCoeffSchedule as TorchEntropyCoeffSchedule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable

@@ -143,7 +145,8 @@ def centralized_critic_postprocessing(policy,
# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, model, dist_class, train_batch):
CentralizedValueMixin.__init__(policy)
func = tf_loss if not policy.config["framework"] == "torch" else torch_loss
func = tf_loss if not policy.config["framework"] == "torch" \
else PPOTorchPolicy.loss

vf_saved = model.value_function
model.value_function = lambda: policy.model.central_value_function(

@@ -194,15 +197,23 @@ CCPPOTFPolicy = PPOTFPolicy.with_updates(
CentralizedValueMixin
])

CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
name="CCPPOTorchPolicy",
postprocess_fn=centralized_critic_postprocessing,
loss_fn=loss_with_central_critic,
before_init=setup_torch_mixins,
mixins=[
TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
CentralizedValueMixin
])

class CCPPOTorchPolicy(PPOTorchPolicy):
def __init__(self, observation_space, action_space, config):
super().__init__(observation_space, action_space, config)
self.compute_central_vf = self.model.central_value_function

@override(PPOTorchPolicy)
def loss(self, model, dist_class, train_batch):
return loss_with_central_critic(self, model, dist_class, train_batch)

@override(PPOTorchPolicy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return centralized_critic_postprocessing(self, sample_batch,
other_agent_batches, episode)

def get_policy_class(config):
@@ -326,7 +326,18 @@ def build_eager_tf_policy(
self._re_trace_counter = 0

self._loss_initialized = False
self._loss = loss_fn
# To ensure backward compatibility:
# Old way: If `loss` provided here, use as-is (as a function).
if loss_fn is not None:
self._loss = loss_fn
# New way: Convert the overridden `self.loss` into a plain
# function, so it can be called the same way as `loss` would
# be, ensuring backward compatibility.
elif self.loss.__func__.__qualname__ != "Policy.loss":
self._loss = self.loss.__func__
# `loss` not provided nor overridden from Policy -> Set to None.
else:
self._loss = None

self.batch_divisibility_req = get_batch_divisibility_req(self) if \
callable(get_batch_divisibility_req) else \

@@ -828,7 +839,7 @@
# Calculate the loss(es) inside a tf GradientTape.
with tf.GradientTape(persistent=compute_gradients_fn is not None) \
as tape:
losses = loss_fn(self, self.model, self.dist_class, samples)
losses = self._loss(self, self.model, self.dist_class, samples)
losses = force_list(losses)

# User provided a compute_gradients_fn.
@@ -7,10 +7,13 @@ import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional, Type, TYPE_CHECKING

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, \
OverrideToImplementCustomLogic
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch

@@ -88,8 +91,7 @@ class Policy(metaclass=ABCMeta):
"""Initializes a Policy instance.

Args:
observation_space: Observation space of the
policy.
observation_space: Observation space of the policy.
action_space: Action space of the policy.
config: A complete Trainer/Policy config dict. For the default
config keys and values, see rllib/trainer/trainer.py.

@@ -404,6 +406,25 @@ class Policy(metaclass=ABCMeta):
# The default implementation just returns the same, unaltered batch.
return sample_batch

@ExperimentalAPI
@OverrideToImplementCustomLogic
def loss(self, model: ModelV2, dist_class: ActionDistribution,
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Loss function for this Policy.

Override this method in order to implement custom loss computations.

Args:
model: The model to calculate the loss(es).
dist_class: The action distribution class to sample actions
from the model's outputs.
train_batch: The input batch on which to calculate the loss.

Returns:
Either a single loss tensor or a list of loss tensors.
"""
raise NotImplementedError

@DeveloperAPI
def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
"""Perform one learning update, given `samples`.

@@ -620,8 +641,8 @@ class Policy(metaclass=ABCMeta):
"""Called on an update to global vars.

Args:
global_vars (Dict[str, TensorType]): Global variables by str key,
broadcast from the driver.
global_vars: Global variables by str key, broadcast from the
driver.
"""
# Store the current global time step (sum over all policies' sample
# steps).

@@ -632,7 +653,7 @@ class Policy(metaclass=ABCMeta):
"""Export Policy checkpoint to local directory.

Args:
export_dir (str): Local writable directory.
export_dir: Local writable directory.
"""
raise NotImplementedError

@@ -646,8 +667,8 @@ class Policy(metaclass=ABCMeta):
implementations for more details.

Args:
export_dir (str): Local writable directory.
onnx (int): If given, will export model in ONNX format. The
export_dir: Local writable directory.
onnx: If given, will export model in ONNX format. The
value of this parameter set the ONNX OpSet version to use.
"""
raise NotImplementedError

@@ -990,6 +1011,10 @@ class Policy(metaclass=ABCMeta):
vr["state_out_{}".format(i)] = ViewRequirement(
space=space, used_for_training=True)

@DeveloperAPI
def __repr__(self):
return type(self).__name__

@Deprecated(new="get_exploration_state", error=False)
def get_exploration_info(self) -> Dict[str, TensorType]:
return self.get_exploration_state()
@@ -25,7 +25,7 @@ jax, _ = try_import_jax()
torch, _ = try_import_torch()

# TODO: (sven) Unify this with `build_tf_policy` as well.
# TODO: Deprecate in favor of directly sub-classing from TorchPolicy.
@DeveloperAPI
def build_policy_class(
name: str,
@@ -8,10 +8,11 @@ import os
import threading
import time
import tree # pip install dm_tree
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, \
TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, \
Union, TYPE_CHECKING

import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

@@ -19,7 +20,7 @@ from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import force_list, NullContextManager
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy

@@ -27,8 +28,8 @@ from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TensorStructType, TrainerConfigDict
from ray.rllib.utils.typing import GradInfoDict, ModelGradients, \
ModelWeights, TensorType, TensorStructType, TrainerConfigDict

if TYPE_CHECKING:
from ray.rllib.evaluation import Episode # noqa

@@ -49,11 +50,12 @@ class TorchPolicy(Policy):
action_space: gym.spaces.Space,
config: TrainerConfigDict,
*,
model: ModelV2,
loss: Callable[[
model: Optional[TorchModelV2] = None,
loss: Optional[Callable[[
Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
], Union[TensorType, List[TensorType]]],
action_distribution_class: Type[TorchDistributionWrapper],
], Union[TensorType, List[TensorType]]]] = None,
action_distribution_class: Optional[Type[
TorchDistributionWrapper]] = None,
action_sampler_fn: Optional[Callable[[
TensorType, List[TensorType]
], Tuple[TensorType, TensorType]]] = None,

@@ -113,7 +115,7 @@ class TorchPolicy(Policy):
get_batch_divisibility_req: Optional callable that returns the
divisibility requirement for sample batches given the Policy.
"""
self.framework = "torch"
self.framework = config["framework"] = "torch"
super().__init__(observation_space, action_space, config)

# Create multi-GPU model towers, if necessary.

@@ -128,6 +130,19 @@ class TorchPolicy(Policy):
# - In case of just one device (1 (fake or real) GPU or 1 CPU), no
# parallelization will be done.

# If no Model is provided, build a default one here.
if model is None:
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], framework=self.framework)
model = ModelCatalog.get_model_v2(
obs_space=self.observation_space,
action_space=self.action_space,
num_outputs=logit_dim,
model_config=self.config["model"],
framework=self.framework)
if action_distribution_class is None:
action_distribution_class = dist_class

# Get devices to build the graph on.
worker_idx = self.config.get("worker_index", 0)
if not config["_fake_gpus"] and \
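Since `model`, `loss`, and `action_distribution_class` are now optional, a TorchPolicy subclass can lean on the default catalog-built model and simply override `loss()`. The following is a structural sketch only (not a complete, trainable policy), assuming the Ray version this diff targets; the loss itself is a toy policy-gradient expression for illustration:

```python
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy


class MinimalTorchPolicy(TorchPolicy):
    def __init__(self, observation_space, action_space, config):
        # No model / loss / dist_class passed: TorchPolicy builds defaults
        # via ModelCatalog, as added in the hunk above.
        super().__init__(
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"])
        self._initialize_loss_from_dummy_batch()

    def loss(self, model, dist_class, train_batch):
        # Plain policy-gradient style loss, for illustration only.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        logp = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -(logp * train_batch[SampleBatch.REWARDS]).mean()
```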
@@ -213,7 +228,18 @@ class TorchPolicy(Policy):
self.exploration = self._create_exploration()
self.unwrapped_model = model # used to support DistributedDataParallel
self._loss = loss
# To ensure backward compatibility:
# Old way: If `loss` provided here, use as-is (as a function).
if loss is not None:
self._loss = loss
# New way: Convert the overridden `self.loss` into a plain function,
# so it can be called the same way as `loss` would be, ensuring
# backward compatibility.
elif self.loss.__func__.__qualname__ != "Policy.loss":
self._loss = self.loss.__func__
# `loss` not provided nor overridden from Policy -> Set to None.
else:
self._loss = None
self._optimizers = force_list(self.optimizer())
# Store, which params (by index within the model's list of
# parameters) should be updated per optimizer.
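The backward-compatibility branch above decides between the old `loss` argument and the new overridden `loss()` method by inspecting `__qualname__`. A self-contained illustration of that detection trick (toy classes, no RLlib involved):

```python
class Base:
    def loss(self):
        raise NotImplementedError

class Old(Base):          # old style: would pass a loss_fn, no override
    pass

class New(Base):          # new style: overrides loss() directly
    def loss(self):
        return 0.0

for cls in (Old, New):
    obj = cls()
    # A bound method's __func__.__qualname__ reveals which class defined it.
    overridden = obj.loss.__func__.__qualname__ != "Base.loss"
    print(cls.__name__, overridden)   # Old False, New True
```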
@@ -616,11 +642,11 @@ class TorchPolicy(Policy):

Returns:
The list of stats tensor (structs) of all towers, copied to this
Policy's device.
Policy's device.

Raises:
AssertionError: If the `stats_name` cannot be found in any one
of the tower's `tower_stats` dicts.
of the tower's `tower_stats` dicts.
"""
data = []
for tower in self.model_gpu_towers:

@@ -697,7 +723,7 @@ class TorchPolicy(Policy):

@DeveloperAPI
def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
loss: TensorType):
loss: TensorType) -> Dict[str, TensorType]:
"""Called after each optimizer.zero_grad() + loss.backward() call.

Called for each self._optimizers/loss-value pair.

@@ -705,22 +731,21 @@ class TorchPolicy(Policy):
E.g. for gradient clipping.

Args:
optimizer (torch.optim.Optimizer): A torch optimizer object.
loss (TensorType): The loss tensor associated with the optimizer.
optimizer: A torch optimizer object.
loss: The loss tensor associated with the optimizer.

Returns:
Dict[str, TensorType]: An dict with information on the gradient
processing step.
An dict with information on the gradient processing step.
"""
return {}

@DeveloperAPI
def extra_compute_grad_fetches(self) -> Dict[str, any]:
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
"""Extra values to fetch and return from compute_gradients().

Returns:
Dict[str, any]: Extra fetch dict to be added to the fetch dict
of the compute_gradients call.
Extra fetch dict to be added to the fetch dict of the
`compute_gradients` call.
"""
return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.

@@ -732,15 +757,15 @@ class TorchPolicy(Policy):
"""Returns dict of extra info to include in experience batch.

Args:
input_dict (Dict[str, TensorType]): Dict of model input tensors.
state_batches (List[TensorType]): List of state tensors.
model (TorchModelV2): Reference to the model object.
action_dist (TorchDistributionWrapper): Torch action dist object
input_dict: Dict of model input tensors.
state_batches: List of state tensors.
model: Reference to the model object.
action_dist: Torch action dist object
to get log-probs (e.g. for already sampled actions).

Returns:
Dict[str, TensorType]: Extra outputs to return in a
compute_actions() call (3rd return value).
Extra outputs to return in a `compute_actions_from_input_dict()`
call (3rd return value).
"""
return {}

@@ -750,12 +775,11 @@ class TorchPolicy(Policy):
"""Return dict of extra grad info.

Args:
train_batch (SampleBatch): The training batch for which to produce
train_batch: The training batch for which to produce
extra grad info for.

Returns:
Dict[str, TensorType]: The info dict carrying grad info per str
key.
The info dict carrying grad info per str key.
"""
return {}

@@ -766,8 +790,7 @@ class TorchPolicy(Policy):
"""Custom the local PyTorch optimizer(s) to use.

Returns:
Union[List[torch.optim.Optimizer], torch.optim.Optimizer]:
The local PyTorch optimizer(s) to use for this Policy.
The local PyTorch optimizer(s) to use for this Policy.
"""
if hasattr(self, "config"):
optimizers = [

@@ -789,7 +812,9 @@ class TorchPolicy(Policy):
Creates a TorchScript model and saves it.

Args:
export_dir (str): Local writable directory or filename.
export_dir: Local writable directory or filename.
onnx: If given, will export model in ONNX format. The
value of this parameter set the ONNX OpSet version to use.
"""
self._lazy_tensor_dict(self._dummy_batch)
# Provide dummy state inputs if not an RNN (torch cannot jit with

@@ -852,8 +877,7 @@ class TorchPolicy(Policy):
"""Shared forward pass logic (w/ and w/o trajectory view API).

Returns:
Tuple:
- actions, state_out, extra_fetches, logp.
A tuple consisting of a) actions, b) state_out, c) extra_fetches.
"""
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep

@@ -956,7 +980,9 @@ class TorchPolicy(Policy):
convert_to_torch_tensor, device=device or self.device))
return postprocessed_batch

def _multi_gpu_parallel_grad_calc(self, sample_batches):
def _multi_gpu_parallel_grad_calc(
self, sample_batches: List[SampleBatch]
) -> List[Tuple[List[TensorType], GradInfoDict]]:
"""Performs a parallelized loss and gradient calculation over the batch.

Splits up the given train batch into n shards (n=number of this

@@ -965,12 +991,12 @@ class TorchPolicy(Policy):
(self.model_gpu_towers). Then returns each tower's outputs.

Args:
sample_batches (List[SampleBatch]): A list of SampleBatch shards to
sample_batches: A list of SampleBatch shards to
calculate loss and gradients for.

Returns:
List[Tuple[List[TensorType], StatsDict]]: A list (one item per
device) of 2-tuples with 1) gradient list and 2) stats dict.
A list (one item per device) of 2-tuples, each with 1) gradient
list and 2) grad info dict.
"""
assert len(self.model_gpu_towers) == len(sample_batches)
lock = threading.Lock()
@@ -65,5 +65,35 @@ def ExperimentalAPI(obj):
return obj

def OverrideToImplementCustomLogic(obj):
"""Users should override this in their sub-classes to implement custom logic.

Used in Trainer and Policy to tag methods that need overriding, e.g.
`Policy.loss()`.
"""
return obj

def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
"""Users should override this in their sub-classes to implement custom logic.

Thereby, it is recommended (but not required) to call the super-class'
corresponding method.

Used in Trainer and Policy to tag methods that need overriding, but the
super class' method should still be called, e.g.
`Trainer.setup()`.

Examples:
>>> @overrides(Trainable)
... @OverrideToImplementCustomLogic_CallToSuperRecommended
... def setup(self, config):
... # implement custom setup logic here ...
... super().setup(config)
... # ... or here (after having called super()'s setup method.
"""
return obj

# Backward compatibility.
Deprecated = Deprecated
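A brief sketch of how the two new decorators are meant to be applied (the class below is hypothetical; per the diff above, the decorators are no-ops that only tag the methods and assume the Ray version that introduces them):

```python
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)


class MyAlgo:
    @OverrideToImplementCustomLogic
    def loss(self, batch):
        # Subclasses are expected to replace this entirely.
        raise NotImplementedError

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def setup(self, config):
        # Subclasses should extend this and call super().setup(config).
        self.config = config
```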