2020-09-02 14:03:01 +02:00
|
|
|
"""
|
|
|
|
PyTorch policy class used for APPO.
|
2019-03-29 12:44:23 -07:00
|
|
|
|
2020-09-02 14:03:01 +02:00
|
|
|
Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
|
|
|
Keep in sync with changes to VTraceTFPolicy.
|
|
|
|
"""
|
2019-03-29 12:44:23 -07:00
|
|
|
|
2020-09-02 14:03:01 +02:00
|
|
|
import gym
|
2019-03-29 12:44:23 -07:00
|
|
|
import numpy as np
|
|
|
|
import logging
|
2020-09-02 14:03:01 +02:00
|
|
|
from typing import Type
|
2019-03-29 12:44:23 -07:00
|
|
|
|
2022-05-19 09:30:42 -07:00
|
|
|
from ray.rllib.algorithms.dqn.simple_q_torch_policy import TargetNetworkMixin
|
2020-04-23 09:11:12 +02:00
|
|
|
import ray.rllib.agents.impala.vtrace_torch as vtrace
|
|
|
|
from ray.rllib.agents.impala.vtrace_torch_policy import (
|
|
|
|
make_time_major,
|
|
|
|
choose_optimizer,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2020-09-02 14:03:01 +02:00
|
|
|
from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model, postprocess_trajectory
|
2019-05-18 00:23:11 -07:00
|
|
|
from ray.rllib.evaluation.postprocessing import Postprocessing
|
2020-09-02 14:03:01 +02:00
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
|
|
from ray.rllib.models.torch.torch_action_dist import (
|
|
|
|
TorchDistributionWrapper,
|
|
|
|
TorchCategorical,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2020-09-02 14:03:01 +02:00
|
|
|
from ray.rllib.policy.policy import Policy
|
2020-12-26 20:14:18 -05:00
|
|
|
from ray.rllib.policy.policy_template import build_policy_class
|
2019-05-20 16:46:05 -07:00
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2022-05-17 08:16:08 -07:00
|
|
|
from ray.rllib.policy.torch_mixins import (
|
|
|
|
EntropyCoeffSchedule,
|
|
|
|
LearningRateSchedule,
|
|
|
|
ValueNetworkMixin,
|
|
|
|
)
|
2020-04-23 09:11:12 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
2021-11-03 10:00:46 +01:00
|
|
|
from ray.rllib.utils.torch_utils import (
|
|
|
|
apply_grad_clipping,
|
|
|
|
explained_variance,
|
|
|
|
global_norm,
|
|
|
|
sequence_mask,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2020-09-02 14:03:01 +02:00
|
|
|
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
2019-05-10 20:36:18 -07:00
|
|
|
|
2020-04-23 09:11:12 +02:00
|
|
|
# Torch module and `torch.nn`, resolved through RLlib's framework helper.
torch, nn = try_import_torch()

# Module-level logger (used for the debug messages in the loss function).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2020-09-02 14:03:01 +02:00
|
|
|
def appo_surrogate_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    If ``policy.config["vtrace"]`` is set, advantages and value targets are
    produced by V-trace, using the target network's (detached) action
    distribution as the "old" policy. The PPO probability ratio is then
    multiplied by an importance-sampling (IS) weight clamped to [0, 2].
    Otherwise, the plain PPO clipped surrogate is used, with advantages and
    value targets read from the ``Postprocessing.ADVANTAGES`` /
    ``Postprocessing.VALUE_TARGETS`` columns of ``train_batch``.

    Side effects: writes loss statistics into ``model.tower_stats``. In the
    V-trace case, it also stores the IS ratio tensor on
    ``policy._is_ratio``.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: The single total loss tensor. It is the mean policy loss,
            plus the weighted mean value-function loss, minus the weighted
            mean entropy, plus the weighted mean KL if
            ``config["use_kl_loss"]`` is set.
    """
    # Target network; its outputs serve as the "old policy" logits for the
    # KL term and the IS correction in the V-trace branch.
    target_model = policy.target_models[model]

    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    # Logit sizes per sub-action. These are used below to split the flat
    # logits into per-sub-action chunks that are passed to V-trace.
    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    # Convenience wrapper that binds the policy and the batch's sequence
    # lengths for `make_time_major`.
    def _make_time_major(*args, **kwargs):
        return make_time_major(
            policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
        )

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    # Distribution inputs recorded at sampling time (the "behaviour" policy).
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    # Only relevant with V-trace. The flag is forwarded to `make_time_major`,
    # and value tensors are sliced with `[:-1]` to match, while
    # `values_time_major[-1]` is used as the V-trace bootstrap value.
    drop_last = policy.config["vtrace"] and policy.config["vtrace_drop_last_ts"]

    if policy.is_recurrent():
        # Recurrent case: average only over positions selected by the
        # sequence mask built from SEQ_LENS. Presumably this excludes
        # padding beyond each sequence's length (TODO confirm against
        # `sequence_mask`).
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=drop_last)
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if policy.config["vtrace"]:
        logger.debug(
            "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
        )

        # Detached, so no gradient flows through the target network's logits.
        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        # Split the flat logits into one tensor per sub-action: explicit
        # sizes for (Multi)Discrete, a single chunk otherwise.
        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1
            )
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1
            )
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1
            )
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1
            )

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=drop_last
        )

        # Compute vtrace on the CPU for better perf.
        # NOTE(review): no explicit CPU transfer is visible here; the V-trace
        # outputs are moved back with `.to(...)` further below.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=drop_last
            ),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=drop_last
            ),
            actions=torch.unbind(
                _make_time_major(loss_actions, drop_last=drop_last), dim=2
            ),
            discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
            * policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=drop_last),
            values=values_time_major[:-1] if drop_last else values_time_major,
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"],
        )

        actions_logp = _make_time_major(action_dist.logp(actions), drop_last=drop_last)
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last
        )
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last
        )
        # IS weight pi_behaviour / pi_target, clamped to [0, 2].
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
        )
        # When the clamp is inactive, this equals pi_current / pi_target.
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        # Kept on the policy so the stats function can report its mean/var.
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        # PPO clipped surrogate objective.
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages
            * torch.clamp(
                logp_ratio,
                1 - policy.config["clip_param"],
                1 + policy.config["clip_param"],
            ),
        )

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        if drop_last:
            delta = values_time_major[:-1] - value_targets
        else:
            delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=drop_last)
        )

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        # pi_current / pi_behaviour
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        # PPO clipped surrogate objective.
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages
            * torch.clamp(
                logp_ratio,
                1 - policy.config["clip_param"],
                1 + policy.config["clip_param"],
            ),
        )

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

    # The summed weighted loss
    total_loss = (
        mean_policy_loss
        + mean_vf_loss * policy.config["vf_loss_coeff"]
        - mean_entropy * policy.entropy_coeff
    )

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["total_loss"] = total_loss
    model.tower_stats["mean_policy_loss"] = mean_policy_loss
    model.tower_stats["mean_kl_loss"] = mean_kl_loss
    model.tower_stats["mean_vf_loss"] = mean_vf_loss
    model.tower_stats["mean_entropy"] = mean_entropy
    model.tower_stats["value_targets"] = value_targets
    # Values are sliced the same way as in the VF loss so both flattened
    # tensors line up element-wise.
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(values_time_major[:-1] if drop_last else values_time_major, [-1]),
    )

    return total_loss
|
|
|
|
|
|
|
|
|
|
|
|
def stats(policy: Policy, train_batch: SampleBatch):
    """Collect APPO training statistics for reporting.

    Per-tower values written by the loss function into ``model.tower_stats``
    are fetched with ``policy.get_tower_stats`` and averaged over all towers.
    Importance-sampling and KL entries are added only when V-trace or the KL
    loss, respectively, are enabled in the config.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """

    def _tower_mean(key):
        # Stack one value per tower, then average across towers.
        return torch.mean(torch.stack(policy.get_tower_stats(key)))

    result = {
        "cur_lr": policy.cur_lr,
        "total_loss": _tower_mean("total_loss"),
        "policy_loss": _tower_mean("mean_policy_loss"),
        "entropy": _tower_mean("mean_entropy"),
        "entropy_coeff": policy.entropy_coeff,
        "var_gnorm": global_norm(policy.model.trainable_variables()),
        "vf_loss": _tower_mean("mean_vf_loss"),
        "vf_explained_var": _tower_mean("vf_explained_var"),
    }

    if policy.config["vtrace"]:
        # IS ratio stashed on the policy by the V-trace branch of the loss.
        ratio = policy._is_ratio
        result["mean_IS"] = torch.mean(ratio, [0, 1])
        result["var_IS"] = torch.var(ratio, [0, 1])

    if policy.config["use_kl_loss"]:
        result["kl"] = _tower_mean("mean_kl_loss")
        result["KL_Coeff"] = policy.kl_coeff

    return result
|
|
|
|
|
|
|
|
|
2020-04-23 09:11:12 +02:00
|
|
|
def add_values(policy, input_dict, state_batches, model, action_dist):
    """Return extra action outputs to record at sampling time.

    With V-trace enabled, nothing extra is returned, because value targets
    are computed inside the loss from V-trace. Otherwise, the model's current
    value-function output is returned under ``SampleBatch.VF_PREDS``.

    Returns:
        dict: A new dict, either empty or holding the VF predictions.
    """
    if policy.config["vtrace"]:
        return {}
    return {SampleBatch.VF_PREDS: model.value_function()}
|
2019-10-31 15:16:02 -07:00
|
|
|
|
2019-07-29 15:02:32 -07:00
|
|
|
|
2021-11-16 14:49:41 +01:00
|
|
|
class KLCoeffMixin:
    """Gives a policy an adaptive KL-penalty coefficient and `update_kl()`.

    `update_kl()` scales `kl_coeff` up or down depending on how a measured
    KL value compares to the fixed `kl_target` taken from the config.
    """

    def __init__(self, config):
        # Adaptive KL-penalty coefficient; rescaled in place by `update_kl()`.
        self.kl_coeff = config["kl_coeff"]
        # Fixed KL target that the coefficient is adapted against.
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        """Adapt `kl_coeff` to a newly measured KL value and return it.

        The coefficient is multiplied by 1.5 if `sampled_kl` is strictly
        above twice `kl_target`. It is halved if `sampled_kl` is strictly
        below half of `kl_target`. Otherwise it is left unchanged.

        Args:
            sampled_kl: The measured KL divergence.

        Returns:
            The (possibly updated) KL coefficient.
        """
        upper_bound = 2.0 * self.kl_target
        lower_bound = 0.5 * self.kl_target
        if sampled_kl > upper_bound:
            self.kl_coeff *= 1.5
        elif sampled_kl < lower_bound:
            self.kl_coeff *= 0.5
        return self.kl_coeff
|
|
|
|
|
|
|
|
|
2020-09-02 14:03:01 +02:00
|
|
|
def setup_early_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
):
    """Run the schedule mixins' constructors before APPOPolicy initialization.

    Initializes the learning-rate schedule, then the entropy-coefficient
    schedule, on `policy`, each fed from its config entries.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)

    entropy_coeff = config["entropy_coeff"]
    entropy_schedule = config["entropy_coeff_schedule"]
    EntropyCoeffSchedule.__init__(policy, entropy_coeff, entropy_schedule)
|
2019-07-29 15:02:32 -07:00
|
|
|
|
|
|
|
|
2020-09-02 14:03:01 +02:00
|
|
|
def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
):
    """Run the remaining mixins' constructors after APPOPolicy initialization.

    In order, this initializes the KL-coefficient mixin, the value-network
    mixin (both given `config`) and the target-network mixin (no extra
    arguments) on `policy`.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    # (mixin class, extra constructor args after `policy`), in call order.
    initializers = (
        (KLCoeffMixin, (config,)),
        (ValueNetworkMixin, (config,)),
        (TargetNetworkMixin, ()),
    )
    for mixin_cls, extra_args in initializers:
        mixin_cls.__init__(policy, *extra_args)
|
2019-07-29 15:02:32 -07:00
|
|
|
|
|
|
|
|
2020-09-02 14:03:01 +02:00
|
|
|
# Build a child class of `TorchPolicy` by wiring the loss, stats,
# postprocessing and mixin-initialization hooks defined above into
# `build_policy_class`.
AsyncPPOTorchPolicy = build_policy_class(
    name="AsyncPPOTorchPolicy",
    framework="torch",
    # Model construction and mixin setup.
    make_model=make_appo_model,
    before_init=setup_early_mixins,
    before_loss_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule,
        KLCoeffMixin,
        TargetNetworkMixin,
        ValueNetworkMixin,
        EntropyCoeffSchedule,
    ],
    # Loss, statistics and optimization hooks.
    loss_fn=appo_surrogate_loss,
    stats_fn=stats,
    optimizer_fn=choose_optimizer,
    extra_grad_process_fn=apply_grad_clipping,
    # Sampling-side hooks.
    extra_action_out_fn=add_values,
    postprocess_fn=postprocess_trajectory,
    # Batch-divisibility requirement: the configured rollout fragment length.
    get_batch_divisibility_req=lambda policy: policy.config["rollout_fragment_length"],
)
|