[rllib] Importance Sampling and KL Loss for APPO (#5051)

This commit is contained in:
Michael Luo 2019-07-29 15:02:32 -07:00 committed by Richard Liaw
parent 3b00144e7d
commit 1337c98f02
10 changed files with 288 additions and 40 deletions

View file

@ -241,8 +241,9 @@ def add_behaviour_logits(policy):
def validate_config(policy, obs_space, action_space, config):
assert config["batch_mode"] == "truncate_episodes", \
"Must use `truncate_episodes` batch mode with V-trace."
if config["vtrace"]:
assert config["batch_mode"] == "truncate_episodes", \
"Must use `truncate_episodes` batch mode with V-trace."
def choose_optimizer(policy, config):

View file

@ -4,6 +4,7 @@ from __future__ import print_function
from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents.ppo.ppo import update_kl
from ray.rllib.agents import impala
# yapf: disable
@ -23,12 +24,17 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
# == PPO surrogate loss options ==
"clip_param": 0.4,
# == PPO KL Loss options ==
"use_kl_loss": False,
"kl_coeff": 1.0,
"kl_target": 0.01,
# == IMPALA optimizer params (see documentation in impala.py) ==
"sample_batch_size": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"num_workers": 2,
"num_gpus": 1,
"num_gpus": 0,
"num_data_loader_buffers": 1,
"minibatch_buffer_size": 1,
"num_sgd_iter": 1,
@ -52,8 +58,34 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
# __sphinx_doc_end__
# yapf: enable
def update_target_and_kl(trainer, fetches):
# Update the KL coeff depending on how many steps LearnerThread has stepped
# through
learner_steps = trainer.optimizer.learner.num_steps
if learner_steps >= trainer.target_update_frequency:
# Update Target Network
trainer.optimizer.learner.num_steps = 0
trainer.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
# Also update KL Coeff
if trainer.config["use_kl_loss"]:
update_kl(trainer, trainer.optimizer.learner.stats)
def initialize_target(trainer):
trainer.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
trainer.target_update_frequency = trainer.config["num_sgd_iter"] \
* trainer.config["minibatch_buffer_size"]
APPOTrainer = impala.ImpalaTrainer.with_updates(
name="APPO",
default_config=DEFAULT_CONFIG,
default_policy=AsyncPPOTFPolicy,
get_policy_class=lambda _: AsyncPPOTFPolicy)
get_policy_class=lambda _: AsyncPPOTFPolicy,
after_init=initialize_target,
after_optimizer_step=update_target_and_kl)

View file

@ -12,15 +12,24 @@ import gym
from ray.rllib.agents.impala import vtrace
from ray.rllib.agents.impala.vtrace_policy import _make_time_major, \
BEHAVIOUR_LOGITS, VTraceTFPolicy
BEHAVIOUR_LOGITS, clip_gradients, \
validate_config, choose_optimizer, ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.utils import try_import_tf
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.agents.ppo.ppo_policy import KLCoeffMixin
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.explained_variance import explained_variance
tf = try_import_tf()
POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"
logger = logging.getLogger(__name__)
@ -36,6 +45,11 @@ class PPOSurrogateLoss(object):
valid_mask: A bool tensor of valid RNN input elements (#2992).
advantages: A float32 tensor of shape [T, B].
value_targets: A float32 tensor of shape [T, B].
vf_loss_coeff (float): Coefficient of the value function loss.
entropy_coeff (float): Coefficient of the entropy regularizer.
clip_param (float): Clip parameter.
cur_kl_coeff (float): Coefficient for KL loss.
use_kl_loss (bool): If true, use KL loss.
"""
def __init__(self,
@ -49,7 +63,11 @@ class PPOSurrogateLoss(object):
value_targets,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
clip_param=0.3):
clip_param=0.3,
cur_kl_coeff=None,
use_kl_loss=False):
def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
logp_ratio = tf.exp(actions_logp - prev_actions_logp)
@ -58,32 +76,37 @@ class PPOSurrogateLoss(object):
advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
1 + clip_param))
self.mean_kl = tf.reduce_mean(action_kl)
self.pi_loss = -tf.reduce_sum(surrogate_loss)
self.mean_kl = reduce_mean_valid(action_kl)
self.pi_loss = -reduce_mean_valid(surrogate_loss)
# The baseline loss
delta = tf.boolean_mask(values - value_targets, valid_mask)
delta = values - value_targets
self.value_targets = value_targets
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
self.vf_loss = 0.5 * reduce_mean_valid(tf.square(delta))
# The entropy loss
self.entropy = tf.reduce_sum(
tf.boolean_mask(actions_entropy, valid_mask))
self.entropy = reduce_mean_valid(actions_entropy)
# The summed weighted loss
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
self.entropy * entropy_coeff)
# Optional additional KL Loss
if use_kl_loss:
self.total_loss += cur_kl_coeff * self.mean_kl
class VTraceSurrogateLoss(object):
def __init__(self,
actions,
prev_actions_logp,
actions_logp,
old_policy_actions_logp,
action_kl,
actions_entropy,
dones,
behaviour_logits,
old_policy_behaviour_logits,
target_logits,
discount,
rewards,
@ -95,8 +118,10 @@ class VTraceSurrogateLoss(object):
entropy_coeff=0.01,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
clip_param=0.3):
"""PPO surrogate loss with vtrace importance weighting.
clip_param=0.3,
cur_kl_coeff=None,
use_kl_loss=False):
"""APPO Loss, with IS modifications and V-trace for Advantage Estimation
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
batch_size. The reason we need to know `B` is for V-trace to properly
@ -106,10 +131,13 @@ class VTraceSurrogateLoss(object):
actions: An int|float32 tensor of shape [T, B, logit_dim].
prev_actions_logp: A float32 tensor of shape [T, B].
actions_logp: A float32 tensor of shape [T, B].
old_policy_actions_logp: A float32 tensor of shape [T, B].
action_kl: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
dones: A bool tensor of shape [T, B].
behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
old_policy_behaviour_logits: A float32 tensor of shape
[T, B, logit_dim].
target_logits: A float32 tensor of shape [T, B, logit_dim].
discount: A float32 scalar.
rewards: A float32 tensor of shape [T, B].
@ -117,13 +145,21 @@ class VTraceSurrogateLoss(object):
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
vf_loss_coeff (float): Coefficient of the value function loss.
entropy_coeff (float): Coefficient of the entropy regularizer.
clip_param (float): Clip parameter.
cur_kl_coeff (float): Coefficient for KL loss.
use_kl_loss (bool): If true, use KL loss.
"""
def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
# Compute vtrace on the CPU for better perf.
with tf.device("/cpu:0"):
self.vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
target_policy_logits=old_policy_behaviour_logits,
actions=tf.unstack(actions, axis=2),
discounts=tf.to_float(~dones) * discount,
rewards=rewards,
@ -134,7 +170,9 @@ class VTraceSurrogateLoss(object):
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
tf.float32))
logp_ratio = tf.exp(actions_logp - prev_actions_logp)
self.is_ratio = tf.clip_by_value(
tf.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
logp_ratio = self.is_ratio * tf.exp(actions_logp - prev_actions_logp)
advantages = self.vtrace_returns.pg_advantages
surrogate_loss = tf.minimum(
@ -142,22 +180,45 @@ class VTraceSurrogateLoss(object):
advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
1 + clip_param))
self.mean_kl = tf.reduce_mean(action_kl)
self.pi_loss = -tf.reduce_sum(surrogate_loss)
self.mean_kl = reduce_mean_valid(action_kl)
self.pi_loss = -reduce_mean_valid(surrogate_loss)
# The baseline loss
delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
delta = values - self.vtrace_returns.vs
self.value_targets = self.vtrace_returns.vs
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
self.vf_loss = 0.5 * reduce_mean_valid(tf.square(delta))
# The entropy loss
self.entropy = tf.reduce_sum(
tf.boolean_mask(actions_entropy, valid_mask))
self.entropy = reduce_mean_valid(actions_entropy)
# The summed weighted loss
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
self.entropy * entropy_coeff)
# Optional additional KL Loss
if use_kl_loss:
self.total_loss += cur_kl_coeff * self.mean_kl
def build_appo_model(policy, obs_space, action_space, config):
policy.model = ModelCatalog.get_model_v2(
obs_space,
action_space,
policy.logit_dim,
config["model"],
name=POLICY_SCOPE,
framework="tf")
policy.target_model = ModelCatalog.get_model_v2(
obs_space,
action_space,
policy.logit_dim,
config["model"],
name=TARGET_POLICY_SCOPE,
framework="tf")
return policy.model
def build_appo_surrogate_loss(policy, batch_tensors):
if isinstance(policy.action_space, gym.spaces.Discrete):
@ -177,14 +238,26 @@ def build_appo_surrogate_loss(policy, batch_tensors):
actions = batch_tensors[SampleBatch.ACTIONS]
dones = batch_tensors[SampleBatch.DONES]
rewards = batch_tensors[SampleBatch.REWARDS]
behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS]
policy.target_model_out, _ = policy.target_model(
policy.input_dict, policy.state_in, policy.seq_lens)
old_policy_behaviour_logits = tf.stop_gradient(policy.target_model_out)
unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1)
unpacked_old_policy_behaviour_logits = tf.split(
old_policy_behaviour_logits, output_hidden_shape, axis=1)
unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1)
action_dist = policy.action_dist
old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits)
prev_action_dist = policy.dist_class(behaviour_logits)
values = policy.value_function
policy.model_vars = policy.model.variables()
policy.target_model_vars = policy.target_model.variables()
if policy.state_in:
max_seq_len = tf.reduce_max(policy.seq_lens) - 1
mask = tf.sequence_mask(policy.seq_lens, max_seq_len)
@ -199,18 +272,27 @@ def build_appo_surrogate_loss(policy, batch_tensors):
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)
# Prepare KL for Loss
mean_kl = make_time_major(
old_policy_action_dist.multi_kl(action_dist), drop_last=True)
policy.loss = VTraceSurrogateLoss(
actions=make_time_major(loss_actions, drop_last=True),
prev_actions_logp=make_time_major(
prev_action_dist.logp(actions), drop_last=True),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
action_kl=prev_action_dist.multi_kl(action_dist),
old_policy_actions_logp=make_time_major(
old_policy_action_dist.logp(actions), drop_last=True),
action_kl=tf.reduce_mean(mean_kl, axis=0)
if is_multidiscrete else mean_kl,
actions_entropy=make_time_major(
action_dist.multi_entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
old_policy_behaviour_logits=make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=True),
target_logits=make_time_major(unpacked_outputs, drop_last=True),
discount=policy.config["gamma"],
rewards=make_time_major(rewards, drop_last=True),
@ -219,17 +301,24 @@ def build_appo_surrogate_loss(policy, batch_tensors):
dist_class=Categorical if is_multidiscrete else policy.dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
entropy_coeff=policy.config["entropy_coeff"],
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=policy.config[
"vtrace_clip_pg_rho_threshold"],
clip_param=policy.config["clip_param"])
clip_param=policy.config["clip_param"],
cur_kl_coeff=policy.kl_coeff,
use_kl_loss=policy.config["use_kl_loss"])
else:
logger.info("Using PPO surrogate loss (vtrace=False)")
# Prepare KL for Loss
mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))
policy.loss = PPOSurrogateLoss(
prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
actions_logp=make_time_major(action_dist.logp(actions)),
action_kl=prev_action_dist.multi_kl(action_dist),
action_kl=tf.reduce_mean(mean_kl, axis=0)
if is_multidiscrete else mean_kl,
actions_entropy=make_time_major(action_dist.multi_entropy()),
values=make_time_major(values),
valid_mask=make_time_major(mask),
@ -238,12 +327,41 @@ def build_appo_surrogate_loss(policy, batch_tensors):
value_targets=make_time_major(
batch_tensors[Postprocessing.VALUE_TARGETS]),
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
clip_param=policy.config["clip_param"])
entropy_coeff=policy.config["entropy_coeff"],
clip_param=policy.config["clip_param"],
cur_kl_coeff=policy.kl_coeff,
use_kl_loss=policy.config["use_kl_loss"])
return policy.loss.total_loss
def stats(policy, batch_tensors):
values_batched = _make_time_major(
policy, policy.value_function, drop_last=policy.config["vtrace"])
stats_dict = {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"policy_loss": policy.loss.pi_loss,
"entropy": policy.loss.entropy,
"var_gnorm": tf.global_norm(policy.var_list),
"vf_loss": policy.loss.vf_loss,
"vf_explained_var": explained_variance(
tf.reshape(policy.loss.value_targets, [-1]),
tf.reshape(values_batched, [-1])),
}
if policy.config["vtrace"]:
is_stat_mean, is_stat_var = tf.nn.moments(policy.loss.is_ratio, [0, 1])
stats_dict.update({"mean_IS": is_stat_mean})
stats_dict.update({"var_IS": is_stat_var})
if policy.config["use_kl_loss"]:
stats_dict.update({"kl": policy.loss.mean_kl})
stats_dict.update({"KL_Coeff": policy.kl_coeff})
return stats_dict
def postprocess_trajectory(policy,
sample_batch,
other_agent_batches=None,
@ -276,8 +394,47 @@ def add_values_and_logits(policy):
return out
AsyncPPOTFPolicy = VTraceTFPolicy.with_updates(
class TargetNetworkMixin(object):
def __init__(self, obs_space, action_space, config):
"""Target Network is updated by the master learner every
trainer.update_target_frequency steps. All worker batches
are importance sampled w.r. to the target network to ensure
a more stable pi_old in PPO.
"""
assign_ops = []
assert len(self.model_vars) == len(self.target_model_vars)
for var, var_target in zip(self.model_vars, self.target_model_vars):
assign_ops.append(var_target.assign(var))
self.update_target_network = tf.group(*assign_ops)
def update_target(self):
return self.get_session().run(self.update_target_network)
def setup_mixins(policy, obs_space, action_space, config):
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
KLCoeffMixin.__init__(policy, config)
ValueNetworkMixin.__init__(policy)
def setup_late_mixins(policy, obs_space, action_space, config):
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
AsyncPPOTFPolicy = build_tf_policy(
name="AsyncPPOTFPolicy",
make_model=build_appo_model,
loss_fn=build_appo_surrogate_loss,
stats_fn=stats,
postprocess_fn=postprocess_trajectory,
extra_action_fetches_fn=add_values_and_logits)
optimizer_fn=choose_optimizer,
gradients_fn=clip_gradients,
extra_action_fetches_fn=add_values_and_logits,
before_init=validate_config,
before_loss_init=setup_mixins,
after_init=setup_late_mixins,
mixins=[
LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
ValueNetworkMixin
],
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])

View file

@ -49,7 +49,8 @@ class LearnerThread(threading.Thread):
inqueue=self.inqueue,
size=minibatch_buffer_size,
timeout=learner_queue_timeout,
num_passes=num_sgd_iter)
num_passes=num_sgd_iter,
init_num_passes=num_sgd_iter)
self.queue_timer = TimerStat()
self.grad_timer = TimerStat()
self.load_timer = TimerStat()
@ -58,6 +59,7 @@ class LearnerThread(threading.Thread):
self.weights_updated = False
self.stats = {}
self.stopped = False
self.num_steps = 0
def run(self):
while not self.stopped:
@ -72,5 +74,6 @@ class LearnerThread(threading.Thread):
self.weights_updated = True
self.stats = get_learner_stats(fetches)
self.num_steps += 1
self.outqueue.put(batch.count)
self.learner_queue_size.push(self.inqueue.qsize())

View file

@ -11,7 +11,7 @@ class MinibatchBuffer(object):
This is for use with AsyncSamplesOptimizer.
"""
def __init__(self, inqueue, size, timeout, num_passes):
def __init__(self, inqueue, size, timeout, num_passes, init_num_passes=1):
"""Initialize a minibatch buffer.
Arguments:
@ -19,12 +19,13 @@ class MinibatchBuffer(object):
size: Max number of data items to buffer.
timeout: Queue timeout
num_passes: Max num times each data item should be emitted.
"""
init_num_passes: Initial max passes for each data item
"""
self.inqueue = inqueue
self.size = size
self.timeout = timeout
self.max_ttl = num_passes
self.cur_max_ttl = 1 # ramp up slowly to better mix the input data
self.cur_max_ttl = init_num_passes
self.buffers = [None] * size
self.ttl = [0] * size
self.idx = 0

View file

@ -131,6 +131,8 @@ class DynamicTFPolicy(TFPolicy):
else:
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.logit_dim = logit_dim
if existing_model:
self.model = existing_model
elif make_model:

View file

@ -430,9 +430,11 @@ class TFPolicy(Policy):
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
if self._prev_action_input is not None and prev_action_batch:
if self._prev_action_input is not None and \
prev_action_batch is not None:
builder.add_feed_dict({self._prev_action_input: prev_action_batch})
if self._prev_reward_input is not None and prev_reward_batch:
if self._prev_reward_input is not None and \
prev_reward_batch is not None:
builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
builder.add_feed_dict({self._is_training: False})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))

View file

@ -0,0 +1,35 @@
# This can reach 9k reward in 2 hours on a Titan XP GPU
# with 16 workers and 8 envs per worker.
halfcheetah-appo:
env: HalfCheetah-v2
run: APPO
stop:
time_total_s: 10800
config:
vtrace: True
gamma: 0.99
lambda: 0.95
sample_batch_size: 512
train_batch_size: 4096
num_workers: 16
num_gpus: 1
broadcast_interval: 1
max_sample_requests_in_flight_per_worker: 1
num_data_loader_buffers: 1
num_envs_per_worker: 32
minibatch_buffer_size: 16
num_sgd_iter: 32
clip_param: 0.2
lr_schedule: [
[0, 0.0005],
[150000000, 0.000001],
]
vf_loss_coeff: 0.5
entropy_coeff: 0.01
grad_clip: 0.5
batch_mode: truncate_episodes
use_kl_loss: True
kl_coeff: 1.0
kl_target: 0.04
observation_filter: MeanStdFilter

View file

@ -1,3 +1,8 @@
# This can reach 18-19 reward in ~5-7 minutes on a Titan XP GPU
# with 32 workers and 8 envs per worker. IMPALA, when ran with
# similar configurations, solved Pong in 10-12 minutes.
# APPO can also solve Pong in 2.5 million timesteps, which is
# 2x more efficient than that of IMPALA.
pong-appo:
env: PongNoFrameskip-v4
run: APPO
@ -5,13 +10,15 @@ pong-appo:
episode_reward_mean: 18.0
timesteps_total: 5000000
config:
vtrace: True
use_kl_loss: False
sample_batch_size: 50
train_batch_size: 750
num_workers: 32
broadcast_interval: 1
max_sample_requests_in_flight_per_worker: 1
num_data_loader_buffers: 1
num_envs_per_worker: 5
num_envs_per_worker: 8
minibatch_buffer_size: 4
num_sgd_iter: 2
vf_loss_coeff: 1.0

View file

@ -2,11 +2,19 @@ pendulum-appo-vt:
env: Pendulum-v0
run: APPO
stop:
episode_reward_mean: -900 # just check it learns a bit
episode_reward_mean: -1200 # just check it learns a bit
timesteps_total: 500000
config:
vtrace: False
num_gpus: 0
num_workers: 1
lambda: 0.1
gamma: 0.95
train_batch_size: 50
vtrace: true
lr: 0.0003
train_batch_size: 100
minibatch_buffer_size: 16
num_sgd_iter: 10
model:
fcnet_hiddens: [64, 64]
batch_mode: complete_episodes
observation_filter: MeanStdFilter