# ray/rllib/policy/torch_mixins.py
#
# 199 lines
# 7.2 KiB
# Python

from typing import Dict, List, Union
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.typing import (
TensorType,
)
# Import torch and torch.nn through RLlib's helper rather than directly.
# NOTE(review): presumably both names are placeholders (e.g. None) when torch
# is not installed — confirm against `try_import_torch`.
torch, nn = try_import_torch()
# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
# and for all possible hyperparams, not just lr.
@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TorchPolicy that adds a learning rate schedule.

    If `lr_schedule` is None, the learning rate stays fixed at `lr`.
    Otherwise `lr_schedule` is wrapped in a `PiecewiseSchedule` (whose
    `outside_value` is the value of the last endpoint). The current learning
    rate is re-evaluated on every `on_global_var_update()` call.
    """

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        """Initializes `self.cur_lr` (and the schedule, if any).

        Args:
            lr: Constant learning rate, used when `lr_schedule` is None.
            lr_schedule: None, or a sequence of `[timestep, value]` endpoints
                handed to `PiecewiseSchedule`. The last endpoint's value is
                used outside the scheduled range.
        """
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = lr
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
            )
            self.cur_lr = self._lr_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        """Re-evaluates the scheduled lr and pushes it into all optimizers.

        Args:
            global_vars: Dict of global variables. Its "timestep" entry is
                read when a schedule is configured.
        """
        super().on_global_var_update(global_vars)
        # Without a schedule the lr never changes, so the optimizers are left
        # untouched. The explicit None-check matches EntropyCoeffSchedule.
        if self._lr_schedule is not None:
            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
            for opt in self._optimizers:
                for param_group in opt.param_groups:
                    param_group["lr"] = self.cur_lr
@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that adds entropy coeff decay.

    Maintains `self.entropy_coeff`. It is either constant or driven by a
    `PiecewiseSchedule` that is re-evaluated on every
    `on_global_var_update()` call.
    """

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        """Initializes `self.entropy_coeff` (and the schedule, if any).

        Args:
            entropy_coeff: Constant entropy coefficient when no schedule is
                given. It is also the timestep-0 value of the legacy
                schedule format.
            entropy_coeff_schedule: One of:
                - None: keep `entropy_coeff` fixed.
                - A list or tuple of `[timestep, value]` endpoints (same
                  format as `lr_schedule`). The last endpoint's value is
                  used outside the scheduled range.
                - Any other value T (legacy format): a schedule going from
                  `entropy_coeff` at timestep 0 to 0.0 at timestep T, and
                  0.0 outside that range.
        """
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = entropy_coeff
        else:
            # Allows for custom schedule similar to lr_schedule format.
            # Tuples are accepted too: previously they fell through to the
            # legacy branch and were misused as a single timestep.
            if isinstance(entropy_coeff_schedule, (list, tuple)):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )
            self.entropy_coeff = self._entropy_coeff_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        """Re-evaluates the scheduled entropy coefficient, if any.

        Args:
            global_vars: Dict of global variables. Its "timestep" entry is
                read when a schedule is configured.
        """
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            self.entropy_coeff = self._entropy_coeff_schedule.value(
                global_vars["timestep"]
            )
class KLCoeffMixin:
    """Assigns the `update_kl()` method to a TorchPolicy.

    This is used by Algorithms to update the KL coefficient
    after each learning step based on `config.kl_target` and
    the measured KL value (from the train_batch).
    """

    def __init__(self, config):
        """Reads the initial KL coefficient and the KL target from `config`.

        Args:
            config: Policy config dict. The "kl_coeff" and "kl_target" keys
                are read.
        """
        # The current KL coefficient (adapted in place by `update_kl()`).
        self.kl_coeff = config["kl_coeff"]
        # Constant target value for the measured KL.
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        """Adapts `self.kl_coeff` to the most recently measured KL.

        The coefficient is multiplied by 1.5 if `sampled_kl` exceeds twice the
        target, and by 0.5 if it is below half the target. Otherwise it is
        left unchanged.

        Args:
            sampled_kl: The measured KL value.

        Returns:
            The (possibly updated) KL coefficient.
        """
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff *= 0.5
        return self.kl_coeff

    @override(TorchPolicy)
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        """Returns the parent's state dict, extended by the KL coefficient."""
        state = super().get_state()
        # Add current kl-coeff value.
        state["current_kl_coeff"] = self.kl_coeff
        return state

    @override(TorchPolicy)
    def set_state(self, state: dict) -> None:
        """Restores the KL coefficient, then the rest of the policy state.

        Falls back to `self.config["kl_coeff"]` if `state` has no
        "current_kl_coeff" entry. The caller's `state` dict is not modified.

        Args:
            state: State dict as produced by `get_state()`.
        """
        # Pop from a shallow copy. Popping from the caller's dict would drop
        # "current_kl_coeff" permanently, so restoring the same state a
        # second time would silently reset the coefficient to the config
        # value.
        state = dict(state)
        self.kl_coeff = state.pop("current_kl_coeff", self.config["kl_coeff"])
        # Call super's set_state with rest of the state dict.
        super().set_state(state)
class ValueNetworkMixin:
    """Assigns the `_value()` method to a TorchPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        """Builds `self._value` depending on `config["use_gae"]`.

        Args:
            config: Policy config dict. Only the "use_gae" key is read here.
        """
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                input_dict = self._lazy_tensor_dict(input_dict)
                # The forward pass is run only for its side effect: afterwards
                # `value_function()` refers to this input. Its outputs are
                # not needed here.
                self.model(input_dict)
                # [0] = remove the batch dim.
                return self.model.value_function()[0].item()

        # When not doing GAE, we do not require the value function's output.
        else:

            def value(*args, **kwargs):
                return 0.0

        self._value = value

    def extra_action_out(self, input_dict, state_batches, model, action_dist):
        """Defines extra fetches per action computation.

        Args:
            input_dict (Dict[str, TensorType]): The input dict used for the action
                computing forward pass.
            state_batches (List[TensorType]): List of state tensors (empty for
                non-RNNs).
            model (ModelV2): The Model object of the Policy.
            action_dist: The instantiated distribution
                object, resulting from the model's outputs and the given
                distribution class.

        Returns:
            Dict[str, TensorType]: Dict with extra (torch) outputs to return
                per action computation.
        """
        # Return value function outputs. VF estimates will hence be added to
        # the SampleBatches produced by the sampler(s) to generate the train
        # batches going into the loss function.
        return {
            SampleBatch.VF_PREDS: model.value_function(),
        }
class TargetNetworkMixin:
    """Gives a TorchPolicy (originally SimpleQTorchPolicy) `update_target()`.

    Per the original design, `update_target()` is invoked by the learner
    every `target_network_update_freq` steps. It is also invoked once on
    construction and after every `set_weights()` call.
    """

    def __init__(self):
        # Start with target network(s) that are exact copies of the Q-net(s).
        self.update_target()

    def update_target(self):
        """Hard-copies the main model's weights into every target model."""
        main_weights = self.model.state_dict()
        for target_model in self.target_models.values():
            target_model.load_state_dict(main_weights)

    @override(TorchPolicy)
    def set_weights(self, weights):
        """Loads `weights` into the main model, then re-syncs the target(s).

        The base implementation is called explicitly as `TorchPolicy`'s.
        The targets are synced afterwards so that restoring weights never
        leaves them stale relative to the main model.

        Args:
            weights: Weights passed through to `TorchPolicy.set_weights`.
        """
        TorchPolicy.set_weights(self, weights)
        self.update_target()