[rllib] Importance Sampling and KL Loss for APPO (#5051)
This commit is contained in:
parent 3b00144e7d
commit 1337c98f02

10 changed files with 288 additions and 40 deletions
@@ -241,8 +241,9 @@ def add_behaviour_logits(policy):


 def validate_config(policy, obs_space, action_space, config):
-    assert config["batch_mode"] == "truncate_episodes", \
-        "Must use `truncate_episodes` batch mode with V-trace."
+    if config["vtrace"]:
+        assert config["batch_mode"] == "truncate_episodes", \
+            "Must use `truncate_episodes` batch mode with V-trace."


 def choose_optimizer(policy, config):
@@ -4,6 +4,7 @@ from __future__ import print_function

 from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy
 from ray.rllib.agents.trainer import with_base_config
+from ray.rllib.agents.ppo.ppo import update_kl
 from ray.rllib.agents import impala

 # yapf: disable
@@ -23,12 +24,17 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
     # == PPO surrogate loss options ==
     "clip_param": 0.4,

+    # == PPO KL Loss options ==
+    "use_kl_loss": False,
+    "kl_coeff": 1.0,
+    "kl_target": 0.01,
+
     # == IMPALA optimizer params (see documentation in impala.py) ==
     "sample_batch_size": 50,
     "train_batch_size": 500,
     "min_iter_time_s": 10,
     "num_workers": 2,
-    "num_gpus": 1,
+    "num_gpus": 0,
     "num_data_loader_buffers": 1,
     "minibatch_buffer_size": 1,
     "num_sgd_iter": 1,
@@ -52,8 +58,34 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
     # __sphinx_doc_end__
     # yapf: enable


+def update_target_and_kl(trainer, fetches):
+    # Update the KL coeff depending on how many steps the LearnerThread
+    # has stepped through
+    learner_steps = trainer.optimizer.learner.num_steps
+    if learner_steps >= trainer.target_update_frequency:
+
+        # Update Target Network
+        trainer.optimizer.learner.num_steps = 0
+        trainer.workers.local_worker().foreach_trainable_policy(
+            lambda p, _: p.update_target())
+
+        # Also update KL Coeff
+        if trainer.config["use_kl_loss"]:
+            update_kl(trainer, trainer.optimizer.learner.stats)
+
+
+def initialize_target(trainer):
+    trainer.workers.local_worker().foreach_trainable_policy(
+        lambda p, _: p.update_target())
+    trainer.target_update_frequency = trainer.config["num_sgd_iter"] \
+        * trainer.config["minibatch_buffer_size"]
+
+
 APPOTrainer = impala.ImpalaTrainer.with_updates(
     name="APPO",
     default_config=DEFAULT_CONFIG,
     default_policy=AsyncPPOTFPolicy,
-    get_policy_class=lambda _: AsyncPPOTFPolicy)
+    get_policy_class=lambda _: AsyncPPOTFPolicy,
+    after_init=initialize_target,
+    after_optimizer_step=update_target_and_kl)
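For reference, the `update_kl` helper imported above (from `ray.rllib.agents.ppo.ppo`) forwards the sampled KL from the learner stats to each policy's `KLCoeffMixin`, which adapts `kl_coeff` toward `kl_target`. A minimal standalone sketch of that adaptation rule, assuming the usual PPO heuristic of scaling by 1.5/0.5 around the target (the exact constants live in `ppo_policy.py`, not in this diff):

def adapt_kl_coeff(kl_coeff, sampled_kl, kl_target):
    # Tighten the penalty if the policy drifted too far from pi_old,
    # relax it if the update was overly conservative.
    if sampled_kl > 2.0 * kl_target:
        kl_coeff *= 1.5
    elif sampled_kl < 0.5 * kl_target:
        kl_coeff *= 0.5
    return kl_coeff

# With kl_target=0.01, a sampled KL of 0.05 raises the coefficient: 1.0 -> 1.5.
print(adapt_kl_coeff(1.0, 0.05, 0.01))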
@@ -12,15 +12,24 @@ import gym

 from ray.rllib.agents.impala import vtrace
 from ray.rllib.agents.impala.vtrace_policy import _make_time_major, \
-    BEHAVIOUR_LOGITS, VTraceTFPolicy
+    BEHAVIOUR_LOGITS, clip_gradients, \
+    validate_config, choose_optimizer, ValueNetworkMixin
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.utils import try_import_tf
+from ray.rllib.policy.tf_policy_template import build_tf_policy
+from ray.rllib.policy.tf_policy import LearningRateSchedule
+from ray.rllib.agents.ppo.ppo_policy import KLCoeffMixin
+from ray.rllib.models import ModelCatalog
+from ray.rllib.utils.explained_variance import explained_variance

 tf = try_import_tf()

+POLICY_SCOPE = "func"
+TARGET_POLICY_SCOPE = "target_func"
+
 logger = logging.getLogger(__name__)

@@ -36,6 +45,11 @@ class PPOSurrogateLoss(object):
         valid_mask: A bool tensor of valid RNN input elements (#2992).
         advantages: A float32 tensor of shape [T, B].
         value_targets: A float32 tensor of shape [T, B].
+        vf_loss_coeff (float): Coefficient of the value function loss.
+        entropy_coeff (float): Coefficient of the entropy regularizer.
+        clip_param (float): Clip parameter.
+        cur_kl_coeff (float): Coefficient for KL loss.
+        use_kl_loss (bool): If true, use KL loss.
     """

     def __init__(self,
@@ -49,7 +63,11 @@ class PPOSurrogateLoss(object):
                  value_targets,
                  vf_loss_coeff=0.5,
                  entropy_coeff=0.01,
-                 clip_param=0.3):
+                 clip_param=0.3,
+                 cur_kl_coeff=None,
+                 use_kl_loss=False):
+        def reduce_mean_valid(t):
+            return tf.reduce_mean(tf.boolean_mask(t, valid_mask))

         logp_ratio = tf.exp(actions_logp - prev_actions_logp)

@@ -58,32 +76,37 @@ class PPOSurrogateLoss(object):
             advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                           1 + clip_param))

-        self.mean_kl = tf.reduce_mean(action_kl)
-        self.pi_loss = -tf.reduce_sum(surrogate_loss)
+        self.mean_kl = reduce_mean_valid(action_kl)
+        self.pi_loss = -reduce_mean_valid(surrogate_loss)

         # The baseline loss
-        delta = tf.boolean_mask(values - value_targets, valid_mask)
+        delta = values - value_targets
         self.value_targets = value_targets
-        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
+        self.vf_loss = 0.5 * reduce_mean_valid(tf.square(delta))

         # The entropy loss
-        self.entropy = tf.reduce_sum(
-            tf.boolean_mask(actions_entropy, valid_mask))
+        self.entropy = reduce_mean_valid(actions_entropy)

         # The summed weighted loss
         self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                            self.entropy * entropy_coeff)

+        # Optional additional KL Loss
+        if use_kl_loss:
+            self.total_loss += cur_kl_coeff * self.mean_kl
+

 class VTraceSurrogateLoss(object):
     def __init__(self,
                  actions,
                  prev_actions_logp,
                  actions_logp,
+                 old_policy_actions_logp,
                  action_kl,
                  actions_entropy,
                  dones,
                  behaviour_logits,
+                 old_policy_behaviour_logits,
                  target_logits,
                  discount,
                  rewards,
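The switch from summed, masked losses to `reduce_mean_valid` normalizes each term by the number of valid timesteps, so the loss scale no longer grows with rollout length or batch size. A small NumPy illustration of the difference (not RLlib code):

import numpy as np

surrogate = np.array([[0.2, 0.4], [0.1, 0.3]])   # shape [T, B]
valid = np.array([[True, True], [True, False]])  # RNN padding mask (#2992)

sum_loss = surrogate[valid].sum()    # 0.7, grows with the number of valid steps
mean_loss = surrogate[valid].mean()  # ~0.233, independent of T and B

print(sum_loss, mean_loss)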
@@ -95,8 +118,10 @@ class VTraceSurrogateLoss(object):
                  entropy_coeff=0.01,
                  clip_rho_threshold=1.0,
                  clip_pg_rho_threshold=1.0,
-                 clip_param=0.3):
-        """PPO surrogate loss with vtrace importance weighting.
+                 clip_param=0.3,
+                 cur_kl_coeff=None,
+                 use_kl_loss=False):
+        """APPO Loss, with IS modifications and V-trace for Advantage Estimation

         VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
         batch_size. The reason we need to know `B` is for V-trace to properly
@@ -106,10 +131,13 @@ class VTraceSurrogateLoss(object):
             actions: An int|float32 tensor of shape [T, B, logit_dim].
             prev_actions_logp: A float32 tensor of shape [T, B].
             actions_logp: A float32 tensor of shape [T, B].
+            old_policy_actions_logp: A float32 tensor of shape [T, B].
             action_kl: A float32 tensor of shape [T, B].
             actions_entropy: A float32 tensor of shape [T, B].
             dones: A bool tensor of shape [T, B].
             behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
+            old_policy_behaviour_logits: A float32 tensor of shape
+                [T, B, logit_dim].
             target_logits: A float32 tensor of shape [T, B, logit_dim].
             discount: A float32 scalar.
             rewards: A float32 tensor of shape [T, B].
@@ -117,13 +145,21 @@ class VTraceSurrogateLoss(object):
             bootstrap_value: A float32 tensor of shape [B].
             dist_class: action distribution class for logits.
             valid_mask: A bool tensor of valid RNN input elements (#2992).
+            vf_loss_coeff (float): Coefficient of the value function loss.
+            entropy_coeff (float): Coefficient of the entropy regularizer.
+            clip_param (float): Clip parameter.
+            cur_kl_coeff (float): Coefficient for KL loss.
+            use_kl_loss (bool): If true, use KL loss.
         """

+        def reduce_mean_valid(t):
+            return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
+
         # Compute vtrace on the CPU for better perf.
         with tf.device("/cpu:0"):
             self.vtrace_returns = vtrace.multi_from_logits(
                 behaviour_policy_logits=behaviour_logits,
-                target_policy_logits=target_logits,
+                target_policy_logits=old_policy_behaviour_logits,
                 actions=tf.unstack(actions, axis=2),
                 discounts=tf.to_float(~dones) * discount,
                 rewards=rewards,
@@ -134,7 +170,9 @@ class VTraceSurrogateLoss(object):
                 clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
                                               tf.float32))

-        logp_ratio = tf.exp(actions_logp - prev_actions_logp)
+        self.is_ratio = tf.clip_by_value(
+            tf.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
+        logp_ratio = self.is_ratio * tf.exp(actions_logp - prev_actions_logp)

         advantages = self.vtrace_returns.pg_advantages
         surrogate_loss = tf.minimum(
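This is the core importance-sampling change: the PPO-style ratio is still taken relative to the behaviour policy that generated the batch, but it is reweighted by a clipped IS ratio between that behaviour policy and the target-network ("old") policy. A NumPy sketch of the same arithmetic with made-up probabilities (illustrative only):

import numpy as np

logp_behaviour = np.log(0.30)  # policy that sampled the batch (worker)
logp_old = np.log(0.25)        # target-network policy (pi_old)
logp_current = np.log(0.35)    # policy currently being optimized

# Clipped correction for the behaviour/target mismatch.
is_ratio = np.clip(np.exp(logp_behaviour - logp_old), 0.0, 2.0)   # 1.2

# Reweighted PPO ratio; when the IS term is not clipped, the product
# reduces to pi / pi_old against the target network (0.35 / 0.25 = 1.4).
logp_ratio = is_ratio * np.exp(logp_current - logp_behaviour)

print(is_ratio, logp_ratio)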
@@ -142,22 +180,45 @@ class VTraceSurrogateLoss(object):
             advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                           1 + clip_param))

-        self.mean_kl = tf.reduce_mean(action_kl)
-        self.pi_loss = -tf.reduce_sum(surrogate_loss)
+        self.mean_kl = reduce_mean_valid(action_kl)
+        self.pi_loss = -reduce_mean_valid(surrogate_loss)

         # The baseline loss
-        delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
+        delta = values - self.vtrace_returns.vs
         self.value_targets = self.vtrace_returns.vs
-        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
+        self.vf_loss = 0.5 * reduce_mean_valid(tf.square(delta))

         # The entropy loss
-        self.entropy = tf.reduce_sum(
-            tf.boolean_mask(actions_entropy, valid_mask))
+        self.entropy = reduce_mean_valid(actions_entropy)

         # The summed weighted loss
         self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                            self.entropy * entropy_coeff)

+        # Optional additional KL Loss
+        if use_kl_loss:
+            self.total_loss += cur_kl_coeff * self.mean_kl
+
+
+def build_appo_model(policy, obs_space, action_space, config):
+    policy.model = ModelCatalog.get_model_v2(
+        obs_space,
+        action_space,
+        policy.logit_dim,
+        config["model"],
+        name=POLICY_SCOPE,
+        framework="tf")
+
+    policy.target_model = ModelCatalog.get_model_v2(
+        obs_space,
+        action_space,
+        policy.logit_dim,
+        config["model"],
+        name=TARGET_POLICY_SCOPE,
+        framework="tf")
+
+    return policy.model
+
+
 def build_appo_surrogate_loss(policy, batch_tensors):
     if isinstance(policy.action_space, gym.spaces.Discrete):
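Schematically, both loss classes now minimize the same composite objective; a compact sketch (not the exact RLlib code):

def appo_total_loss(pi_loss, vf_loss, entropy, mean_kl,
                    vf_loss_coeff=0.5, entropy_coeff=0.01,
                    cur_kl_coeff=1.0, use_kl_loss=False):
    # Clipped surrogate + weighted value error - entropy bonus,
    # plus an optional adaptive KL penalty toward the target network.
    total = pi_loss + vf_loss * vf_loss_coeff - entropy * entropy_coeff
    if use_kl_loss:
        total += cur_kl_coeff * mean_kl
    return total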
@@ -177,14 +238,26 @@ def build_appo_surrogate_loss(policy, batch_tensors):
     actions = batch_tensors[SampleBatch.ACTIONS]
     dones = batch_tensors[SampleBatch.DONES]
     rewards = batch_tensors[SampleBatch.REWARDS]

     behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS]

+    policy.target_model_out, _ = policy.target_model(
+        policy.input_dict, policy.state_in, policy.seq_lens)
+    old_policy_behaviour_logits = tf.stop_gradient(policy.target_model_out)
+
     unpacked_behaviour_logits = tf.split(
         behaviour_logits, output_hidden_shape, axis=1)
+    unpacked_old_policy_behaviour_logits = tf.split(
+        old_policy_behaviour_logits, output_hidden_shape, axis=1)
     unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1)
     action_dist = policy.action_dist
+    old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits)
     prev_action_dist = policy.dist_class(behaviour_logits)
     values = policy.value_function

+    policy.model_vars = policy.model.variables()
+    policy.target_model_vars = policy.target_model.variables()
+
     if policy.state_in:
         max_seq_len = tf.reduce_max(policy.seq_lens) - 1
         mask = tf.sequence_mask(policy.seq_lens, max_seq_len)
|
||||||
loss_actions = actions if is_multidiscrete else tf.expand_dims(
|
loss_actions = actions if is_multidiscrete else tf.expand_dims(
|
||||||
actions, axis=1)
|
actions, axis=1)
|
||||||
|
|
||||||
|
# Prepare KL for Loss
|
||||||
|
mean_kl = make_time_major(
|
||||||
|
old_policy_action_dist.multi_kl(action_dist), drop_last=True)
|
||||||
|
|
||||||
policy.loss = VTraceSurrogateLoss(
|
policy.loss = VTraceSurrogateLoss(
|
||||||
actions=make_time_major(loss_actions, drop_last=True),
|
actions=make_time_major(loss_actions, drop_last=True),
|
||||||
prev_actions_logp=make_time_major(
|
prev_actions_logp=make_time_major(
|
||||||
prev_action_dist.logp(actions), drop_last=True),
|
prev_action_dist.logp(actions), drop_last=True),
|
||||||
actions_logp=make_time_major(
|
actions_logp=make_time_major(
|
||||||
action_dist.logp(actions), drop_last=True),
|
action_dist.logp(actions), drop_last=True),
|
||||||
action_kl=prev_action_dist.multi_kl(action_dist),
|
old_policy_actions_logp=make_time_major(
|
||||||
|
old_policy_action_dist.logp(actions), drop_last=True),
|
||||||
|
action_kl=tf.reduce_mean(mean_kl, axis=0)
|
||||||
|
if is_multidiscrete else mean_kl,
|
||||||
actions_entropy=make_time_major(
|
actions_entropy=make_time_major(
|
||||||
action_dist.multi_entropy(), drop_last=True),
|
action_dist.multi_entropy(), drop_last=True),
|
||||||
dones=make_time_major(dones, drop_last=True),
|
dones=make_time_major(dones, drop_last=True),
|
||||||
behaviour_logits=make_time_major(
|
behaviour_logits=make_time_major(
|
||||||
unpacked_behaviour_logits, drop_last=True),
|
unpacked_behaviour_logits, drop_last=True),
|
||||||
|
old_policy_behaviour_logits=make_time_major(
|
||||||
|
unpacked_old_policy_behaviour_logits, drop_last=True),
|
||||||
target_logits=make_time_major(unpacked_outputs, drop_last=True),
|
target_logits=make_time_major(unpacked_outputs, drop_last=True),
|
||||||
discount=policy.config["gamma"],
|
discount=policy.config["gamma"],
|
||||||
rewards=make_time_major(rewards, drop_last=True),
|
rewards=make_time_major(rewards, drop_last=True),
|
||||||
|
@@ -219,17 +301,24 @@ def build_appo_surrogate_loss(policy, batch_tensors):
             dist_class=Categorical if is_multidiscrete else policy.dist_class,
             valid_mask=make_time_major(mask, drop_last=True),
             vf_loss_coeff=policy.config["vf_loss_coeff"],
-            entropy_coeff=policy.entropy_coeff,
+            entropy_coeff=policy.config["entropy_coeff"],
             clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
             clip_pg_rho_threshold=policy.config[
                 "vtrace_clip_pg_rho_threshold"],
-            clip_param=policy.config["clip_param"])
+            clip_param=policy.config["clip_param"],
+            cur_kl_coeff=policy.kl_coeff,
+            use_kl_loss=policy.config["use_kl_loss"])
     else:
         logger.info("Using PPO surrogate loss (vtrace=False)")

+        # Prepare KL for Loss
+        mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))
+
         policy.loss = PPOSurrogateLoss(
             prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
             actions_logp=make_time_major(action_dist.logp(actions)),
-            action_kl=prev_action_dist.multi_kl(action_dist),
+            action_kl=tf.reduce_mean(mean_kl, axis=0)
+            if is_multidiscrete else mean_kl,
             actions_entropy=make_time_major(action_dist.multi_entropy()),
             values=make_time_major(values),
             valid_mask=make_time_major(mask),
@@ -238,12 +327,41 @@ def build_appo_surrogate_loss(policy, batch_tensors):
             value_targets=make_time_major(
                 batch_tensors[Postprocessing.VALUE_TARGETS]),
             vf_loss_coeff=policy.config["vf_loss_coeff"],
-            entropy_coeff=policy.entropy_coeff,
-            clip_param=policy.config["clip_param"])
+            entropy_coeff=policy.config["entropy_coeff"],
+            clip_param=policy.config["clip_param"],
+            cur_kl_coeff=policy.kl_coeff,
+            use_kl_loss=policy.config["use_kl_loss"])

     return policy.loss.total_loss


+def stats(policy, batch_tensors):
+    values_batched = _make_time_major(
+        policy, policy.value_function, drop_last=policy.config["vtrace"])
+
+    stats_dict = {
+        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
+        "policy_loss": policy.loss.pi_loss,
+        "entropy": policy.loss.entropy,
+        "var_gnorm": tf.global_norm(policy.var_list),
+        "vf_loss": policy.loss.vf_loss,
+        "vf_explained_var": explained_variance(
+            tf.reshape(policy.loss.value_targets, [-1]),
+            tf.reshape(values_batched, [-1])),
+    }
+
+    if policy.config["vtrace"]:
+        is_stat_mean, is_stat_var = tf.nn.moments(policy.loss.is_ratio, [0, 1])
+        stats_dict.update({"mean_IS": is_stat_mean})
+        stats_dict.update({"var_IS": is_stat_var})
+
+    if policy.config["use_kl_loss"]:
+        stats_dict.update({"kl": policy.loss.mean_kl})
+        stats_dict.update({"KL_Coeff": policy.kl_coeff})
+
+    return stats_dict
+
+
 def postprocess_trajectory(policy,
                            sample_batch,
                            other_agent_batches=None,
@@ -276,8 +394,47 @@ def add_values_and_logits(policy):
     return out


-AsyncPPOTFPolicy = VTraceTFPolicy.with_updates(
+class TargetNetworkMixin(object):
+    def __init__(self, obs_space, action_space, config):
+        """The target network is updated by the master learner every
+        trainer.update_target_frequency steps. All worker batches are
+        importance sampled with respect to the target network to ensure
+        a more stable pi_old in PPO.
+        """
+        assign_ops = []
+        assert len(self.model_vars) == len(self.target_model_vars)
+        for var, var_target in zip(self.model_vars, self.target_model_vars):
+            assign_ops.append(var_target.assign(var))
+        self.update_target_network = tf.group(*assign_ops)
+
+    def update_target(self):
+        return self.get_session().run(self.update_target_network)
+
+
+def setup_mixins(policy, obs_space, action_space, config):
+    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
+    KLCoeffMixin.__init__(policy, config)
+    ValueNetworkMixin.__init__(policy)
+
+
+def setup_late_mixins(policy, obs_space, action_space, config):
+    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
+
+
+AsyncPPOTFPolicy = build_tf_policy(
     name="AsyncPPOTFPolicy",
+    make_model=build_appo_model,
     loss_fn=build_appo_surrogate_loss,
+    stats_fn=stats,
     postprocess_fn=postprocess_trajectory,
-    extra_action_fetches_fn=add_values_and_logits)
+    optimizer_fn=choose_optimizer,
+    gradients_fn=clip_gradients,
+    extra_action_fetches_fn=add_values_and_logits,
+    before_init=validate_config,
+    before_loss_init=setup_mixins,
+    after_init=setup_late_mixins,
+    mixins=[
+        LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
+        ValueNetworkMixin
+    ],
+    get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
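A minimal usage sketch of the new options (assuming APPO is importable as `ray.rllib.agents.ppo.appo.APPOTrainer` on this branch; the environment and iteration count are just placeholders):

import ray
from ray.rllib.agents.ppo.appo import APPOTrainer

ray.init()

trainer = APPOTrainer(
    env="CartPole-v0",
    config={
        "vtrace": True,       # VTraceSurrogateLoss path with the IS correction
        "use_kl_loss": True,  # enable the adaptive KL penalty
        "kl_coeff": 1.0,
        "kl_target": 0.01,
        "num_workers": 2,
    })

for _ in range(3):
    result = trainer.train()
    # With vtrace / use_kl_loss enabled, mean_IS, var_IS, kl and KL_Coeff
    # show up in the learner stats of the result dict.
    print(result["episode_reward_mean"])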
@@ -49,7 +49,8 @@ class LearnerThread(threading.Thread):
             inqueue=self.inqueue,
             size=minibatch_buffer_size,
             timeout=learner_queue_timeout,
-            num_passes=num_sgd_iter)
+            num_passes=num_sgd_iter,
+            init_num_passes=num_sgd_iter)
         self.queue_timer = TimerStat()
         self.grad_timer = TimerStat()
         self.load_timer = TimerStat()
@@ -58,6 +59,7 @@ class LearnerThread(threading.Thread):
         self.weights_updated = False
         self.stats = {}
         self.stopped = False
+        self.num_steps = 0

     def run(self):
         while not self.stopped:
@@ -72,5 +74,6 @@ class LearnerThread(threading.Thread):
                 self.weights_updated = True
                 self.stats = get_learner_stats(fetches)

+        self.num_steps += 1
         self.outqueue.put(batch.count)
         self.learner_queue_size.push(self.inqueue.qsize())
@@ -11,7 +11,7 @@ class MinibatchBuffer(object):
     This is for use with AsyncSamplesOptimizer.
     """

-    def __init__(self, inqueue, size, timeout, num_passes):
+    def __init__(self, inqueue, size, timeout, num_passes, init_num_passes=1):
         """Initialize a minibatch buffer.

         Arguments:
@@ -19,12 +19,13 @@ class MinibatchBuffer(object):
            size: Max number of data items to buffer.
            timeout: Queue timeout
            num_passes: Max num times each data item should be emitted.
+           init_num_passes: Initial max passes for each data item
        """
        self.inqueue = inqueue
        self.size = size
        self.timeout = timeout
        self.max_ttl = num_passes
-       self.cur_max_ttl = 1  # ramp up slowly to better mix the input data
+       self.cur_max_ttl = init_num_passes
        self.buffers = [None] * size
        self.ttl = [0] * size
        self.idx = 0
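The new `init_num_passes` argument lets the APPO learner start at the full pass count instead of ramping up from a single pass per batch. A toy, purely illustrative sketch of the effect (the real buffer tracks a per-slot ttl; the doubling ramp here is an assumption for illustration only):

def max_passes_at(iteration, num_passes, init_num_passes=1):
    # Hypothetical ramp: start at init_num_passes and grow toward num_passes.
    return min(init_num_passes * (2 ** iteration), num_passes)

print([max_passes_at(i, 4) for i in range(4)])     # ramp from 1: [1, 2, 4, 4]
print([max_passes_at(i, 4, 4) for i in range(4)])  # start full:  [4, 4, 4, 4]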
@@ -131,6 +131,8 @@ class DynamicTFPolicy(TFPolicy):
         else:
             self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                 action_space, self.config["model"])
+        self.logit_dim = logit_dim
+
         if existing_model:
             self.model = existing_model
         elif make_model:
@@ -430,9 +430,11 @@ class TFPolicy(Policy):
         builder.add_feed_dict({self._obs_input: obs_batch})
         if state_batches:
             builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
-        if self._prev_action_input is not None and prev_action_batch:
+        if self._prev_action_input is not None and \
+                prev_action_batch is not None:
             builder.add_feed_dict({self._prev_action_input: prev_action_batch})
-        if self._prev_reward_input is not None and prev_reward_batch:
+        if self._prev_reward_input is not None and \
+                prev_reward_batch is not None:
             builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
         builder.add_feed_dict({self._is_training: False})
         builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
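The `is not None` checks fix a real pitfall: a batched NumPy array has no single truth value, so the old `and prev_action_batch` test raised for multi-element batches. A short illustration:

import numpy as np

prev_action_batch = np.array([1, 0, 1])

try:
    if prev_action_batch:              # old-style truthiness check
        pass
except ValueError as err:
    print(err)                         # "The truth value of an array ... is ambiguous"

if prev_action_batch is not None:      # new-style check works for any batch
    print("feed prev actions")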
python/ray/rllib/tuned_examples/halfcheetah-appo.yaml (new file, 35 lines)

@@ -0,0 +1,35 @@
+# This can reach 9k reward in 2 hours on a Titan XP GPU
+# with 16 workers and 8 envs per worker.
+halfcheetah-appo:
+    env: HalfCheetah-v2
+    run: APPO
+    stop:
+        time_total_s: 10800
+    config:
+        vtrace: True
+        gamma: 0.99
+        lambda: 0.95
+        sample_batch_size: 512
+        train_batch_size: 4096
+        num_workers: 16
+        num_gpus: 1
+        broadcast_interval: 1
+        max_sample_requests_in_flight_per_worker: 1
+        num_data_loader_buffers: 1
+        num_envs_per_worker: 32
+        minibatch_buffer_size: 16
+        num_sgd_iter: 32
+        clip_param: 0.2
+        lr_schedule: [
+            [0, 0.0005],
+            [150000000, 0.000001],
+        ]
+        vf_loss_coeff: 0.5
+        entropy_coeff: 0.01
+        grad_clip: 0.5
+        batch_mode: truncate_episodes
+        use_kl_loss: True
+        kl_coeff: 1.0
+        kl_target: 0.04
+        observation_filter: MeanStdFilter
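One way to launch the tuned example above (a sketch; it assumes the file path added in this commit and the standard name -> {env, run, stop, config} layout that `tune.run_experiments` expects):

import yaml

import ray
from ray import tune

ray.init()

with open("python/ray/rllib/tuned_examples/halfcheetah-appo.yaml") as f:
    experiments = yaml.safe_load(f)

tune.run_experiments(experiments)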
@@ -1,3 +1,8 @@
+# This can reach 18-19 reward in ~5-7 minutes on a Titan XP GPU
+# with 32 workers and 8 envs per worker. IMPALA, when run with
+# similar configurations, solved Pong in 10-12 minutes.
+# APPO can also solve Pong in 2.5 million timesteps, which is
+# 2x more efficient than IMPALA.
 pong-appo:
     env: PongNoFrameskip-v4
     run: APPO
@@ -5,13 +10,15 @@ pong-appo:
         episode_reward_mean: 18.0
         timesteps_total: 5000000
     config:
+        vtrace: True
+        use_kl_loss: False
         sample_batch_size: 50
         train_batch_size: 750
         num_workers: 32
         broadcast_interval: 1
         max_sample_requests_in_flight_per_worker: 1
         num_data_loader_buffers: 1
-        num_envs_per_worker: 5
+        num_envs_per_worker: 8
         minibatch_buffer_size: 4
         num_sgd_iter: 2
         vf_loss_coeff: 1.0
@@ -2,11 +2,19 @@ pendulum-appo-vt:
     env: Pendulum-v0
     run: APPO
     stop:
-        episode_reward_mean: -900 # just check it learns a bit
+        episode_reward_mean: -1200 # just check it learns a bit
         timesteps_total: 500000
     config:
+        vtrace: False
         num_gpus: 0
         num_workers: 1
+        lambda: 0.1
         gamma: 0.95
-        train_batch_size: 50
-        vtrace: true
+        lr: 0.0003
+        train_batch_size: 100
+        minibatch_buffer_size: 16
+        num_sgd_iter: 10
+        model:
+            fcnet_hiddens: [64, 64]
+        batch_mode: complete_episodes
+        observation_filter: MeanStdFilter