from typing import Dict, List, Union

from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()


# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
# and for all possible hyperparams, not just lr.
@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TorchPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = lr
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
            )
            self.cur_lr = self._lr_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule:
            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
            for opt in self._optimizers:
                for p in opt.param_groups:
                    p["lr"] = self.cur_lr
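

# Illustrative sketch (hypothetical helper, not part of RLlib's API): the
# `lr_schedule` argument above is a list of `[timestep, lr_value]` breakpoints
# that `PiecewiseSchedule` interpolates linearly; past the last breakpoint,
# `outside_value` (the final lr) is returned. This mirrors how the mixin
# anneals `cur_lr` as `global_vars["timestep"]` grows.
def _example_lr_schedule():
    lr_schedule = [[0, 1e-3], [100000, 1e-4]]
    schedule = PiecewiseSchedule(
        lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
    )
    # t=0 -> 1e-3; t=50000 -> 5.5e-4 (linear interpolation);
    # t=200000 -> 1e-4 (outside_value past the last breakpoint).
    return {t: schedule.value(t) for t in (0, 50000, 200000)}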


@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = entropy_coeff
        else:
            # Allows for custom schedule similar to lr_schedule format.
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value.
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )
            self.entropy_coeff = self._entropy_coeff_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            self.entropy_coeff = self._entropy_coeff_schedule.value(
                global_vars["timestep"]
            )
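

# Illustrative sketch (hypothetical helper, not part of RLlib's API): the two
# schedule formats the mixin accepts. A list is used verbatim as piecewise
# breakpoints; a single number is treated as the timestep by which the
# coefficient has decayed linearly from `entropy_coeff` to 0.0.
def _example_entropy_coeff_schedules():
    # List form, e.g. entropy_coeff_schedule=[[0, 0.01], [50000, 0.001]]:
    piecewise = PiecewiseSchedule(
        [[0, 0.01], [50000, 0.001]], outside_value=0.001, framework=None
    )
    # Scalar form, e.g. entropy_coeff=0.01, entropy_coeff_schedule=100000:
    linear_decay = PiecewiseSchedule(
        [[0, 0.01], [100000, 0.0]], outside_value=0.0, framework=None
    )
    # 0.0055 (halfway between breakpoints) and 0.005 (half the decay horizon).
    return piecewise.value(25000), linear_decay.value(50000)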


class KLCoeffMixin:
    """Assigns the `update_kl()` method to a TorchPolicy.

    This is used by Algorithms to update the KL coefficient
    after each learning step based on `config.kl_target` and
    the measured KL value (from the train_batch).
    """

    def __init__(self, config):
        # The current KL value (as python float).
        self.kl_coeff = config["kl_coeff"]
        # Constant target value.
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        # Update the current KL value based on the recently measured value.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff *= 0.5
        # Return the current KL value.
        return self.kl_coeff

    @override(TorchPolicy)
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        state = super().get_state()
        # Add current kl-coeff value.
        state["current_kl_coeff"] = self.kl_coeff
        return state

    @override(TorchPolicy)
    def set_state(self, state: dict) -> None:
        # Set current kl-coeff value first.
        self.kl_coeff = state.pop("current_kl_coeff", self.config["kl_coeff"])
        # Call super's set_state with rest of the state dict.
        super().set_state(state)
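

# Illustrative sketch (standalone, hypothetical helper): the adaptive-KL rule
# implemented by `update_kl()` above. If the measured KL drifts above twice
# the target, the penalty coefficient is increased by 1.5x; if it falls below
# half the target, it is halved; otherwise it is left unchanged.
def _example_kl_update(kl_coeff=0.2, kl_target=0.01, sampled_kl=0.03):
    if sampled_kl > 2.0 * kl_target:
        kl_coeff *= 1.5
    elif sampled_kl < 0.5 * kl_target:
        kl_coeff *= 0.5
    # With the defaults above: 0.03 > 2 * 0.01, so 0.2 -> 0.3.
    return kl_coeff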


class ValueNetworkMixin:
    """Assigns the `_value()` method to a TorchPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.

            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                input_dict = self._lazy_tensor_dict(input_dict)
                model_out, _ = self.model(input_dict)
                # [0] = remove the batch dim.
                return self.model.value_function()[0].item()

        # When not doing GAE, we do not require the value function's output.
        else:

            def value(*args, **kwargs):
                return 0.0

        self._value = value

    def extra_action_out(self, input_dict, state_batches, model, action_dist):
        """Defines extra fetches per action computation.

        Args:
            input_dict (Dict[str, TensorType]): The input dict used for the
                action computing forward pass.
            state_batches (List[TensorType]): List of state tensors (empty for
                non-RNNs).
            model (ModelV2): The Model object of the Policy.
            action_dist: The instantiated distribution object, resulting from
                the model's outputs and the given distribution class.

        Returns:
            Dict[str, TensorType]: Dict with extra fetches to perform per
                action computation.
        """
        # Return value function outputs. VF estimates will hence be added to
        # the SampleBatches produced by the sampler(s) to generate the train
        # batches going into the loss function.
        return {
            SampleBatch.VF_PREDS: model.value_function(),
        }
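

# Illustrative sketch (hypothetical `policy` and `sample_batch` arguments, not
# tied to a specific algorithm): how trajectory postprocessing typically uses
# `_value()`. When an episode is truncated rather than terminated, the value
# estimate of the last (next-)observation bootstraps the return; otherwise
# 0.0 is used. Note that `_value()` triggers a real forward pass.
def _example_bootstrap_value(policy, sample_batch):
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    else:
        # Keep the batch dim of size 1; `_value()` strips it via `[0]`.
        last_r = policy._value(
            **{SampleBatch.OBS: sample_batch[SampleBatch.NEXT_OBS][-1:]}
        )
    return last_r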


class TargetNetworkMixin:
    """Assigns the `update_target()` method to the SimpleQTorchPolicy.

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self):
        # Hard initial update from Q-net(s) to target Q-net(s).
        self.update_target()

    def update_target(self):
        # Update_target_fn will be called periodically to copy Q network to
        # target Q networks.
        state_dict = self.model.state_dict()
        for target in self.target_models.values():
            target.load_state_dict(state_dict)

    @override(TorchPolicy)
    def set_weights(self, weights):
        # Makes sure that whenever we restore weights for this policy's
        # model, we sync the target network (from the main model)
        # at the same time.
        TorchPolicy.set_weights(self, weights)
        self.update_target()
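

# Illustrative sketch (plain PyTorch, hypothetical nets; assumes torch is
# installed): the "hard" target sync performed by `update_target()` above is
# simply a full state-dict copy from the online Q-network into the target
# network, which is otherwise held fixed between syncs.
def _example_hard_target_sync(q_net, target_net):
    # Copy every parameter and buffer from the online net to the target net.
    target_net.load_state_dict(q_net.state_dict())
    return target_net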