[rllib] Revert "[rllib] Port DDPG to the build_tf_policy pattern" (#5626)

Authored by Eric Liang on 2019-09-04 21:39:22 -07:00, committed by Philipp Moritz
parent 1823ea74e3
commit dcff263ce9
6 changed files with 658 additions and 499 deletions

View file

@@ -54,7 +54,7 @@ PICKLE_OBJECT_WARNING_SIZE = 10**7
 # The maximum resource quantity that is allowed. TODO(rkn): This could be
 # relaxed, but the current implementation of the node manager will be slower
 # for large resource quantities due to bookkeeping of specific resource IDs.
-MAX_RESOURCE_QUANTITY = 10000
+MAX_RESOURCE_QUANTITY = 20000
 # Each memory "resource" counts as this many bytes of memory.
 MEMORY_RESOURCE_UNIT_BYTES = 50 * 1024 * 1024
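
# A minimal sketch of the accounting above, assuming a hypothetical helper
# name: raw byte amounts map to whole 50 MiB memory "resource" units,
# rounding down so a claim never exceeds the bytes actually available.
MEMORY_RESOURCE_UNIT_BYTES = 50 * 1024 * 1024

def to_memory_units(memory_bytes):
    # e.g. 1 GiB -> 20 units of 50 MiB each
    return memory_bytes // MEMORY_RESOURCE_UNIT_BYTES

assert to_memory_units(1024**3) == 20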

View file

@@ -41,7 +41,7 @@ DEFAULT_CONFIG = with_common_config({
     # === Model ===
     # Apply a state preprocessor with spec given by the "model" config option
     # (like other RL algorithms). This is mostly useful if you have a weird
-    # observation shape, like an image. Auto-enabled if a custom model is set.
+    # observation shape, like an image. Disabled by default.
     "use_state_preprocessor": False,
     # Postprocess the policy network model output with these hidden layers. If
     # use_state_preprocessor is False, then these will be the *only* hidden
@@ -173,7 +173,7 @@ def make_exploration_schedule(config, worker_index):
if config["per_worker_exploration"]:
assert config["num_workers"] > 1, "This requires multiple workers"
if worker_index >= 0:
# Exploration constants from the Ape-X paper
# FIXME: what do magic constants mean? (0.4, 7)
max_index = float(config["num_workers"] - 1)
exponent = 1 + worker_index / max_index * 7
return ConstantSchedule(0.4**exponent)
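
# A minimal sketch of the per-worker schedule built above: worker i of N
# explorers gets a constant epsilon of 0.4 ** (1 + 7 * i / (N - 1)), so the
# noise scales fan out geometrically from 0.4 down to 0.4 ** 8 (the Ape-X
# recipe).

def per_worker_epsilon(worker_index, num_workers):
    max_index = float(num_workers - 1)
    return 0.4 ** (1 + worker_index / max_index * 7)

# With 8 workers: 0.4, 0.16, 0.064, ..., ~0.00066
epsilons = [per_worker_epsilon(i, 8) for i in range(8)]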

View file

@@ -4,37 +4,80 @@ from __future__ import print_function
from gym.spaces import Box
import numpy as np
import logging
import ray
import ray.experimental.tf_utils
from ray.rllib.agents.ddpg.ddpg_model import DDPGModel
from ray.rllib.agents.ddpg.noop_model import NoopModel
from ray.rllib.agents.dqn.dqn_policy import _postprocess_dqn, PRIO_WEIGHTS
from ray.rllib.agents.dqn.dqn_policy import _postprocess_dqn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \
make_tf_callable
from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, scope_vars
tf = try_import_tf()
logger = logging.getLogger(__name__)
ACTION_SCOPE = "action"
POLICY_SCOPE = "policy"
POLICY_TARGET_SCOPE = "target_policy"
Q_SCOPE = "critic"
Q_TARGET_SCOPE = "target_critic"
TWIN_Q_SCOPE = "twin_critic"
TWIN_Q_TARGET_SCOPE = "twin_target_critic"
# Importance sampling weights for prioritized replay
PRIO_WEIGHTS = "weights"
def build_ddpg_model(policy, obs_space, action_space, config):
if config["model"]["custom_model"]:
logger.warning(
"Setting use_state_preprocessor=True since a custom model "
"was specified.")
config["use_state_preprocessor"] = True
class DDPGPostprocessing(object):
"""Implements n-step learning and param noise adjustments."""
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
if self.config["parameter_noise"]:
# adjust the sigma of parameter space noise
states, noisy_actions = [
list(x) for x in sample_batch.columns(
[SampleBatch.CUR_OBS, SampleBatch.ACTIONS])
]
self.sess.run(self.remove_noise_op)
clean_actions = self.sess.run(
self.output_actions,
feed_dict={
self.cur_observations: states,
self.stochastic: False,
self.noise_scale: .0,
self.pure_exploration_phase: False,
})
distance_in_action_space = np.sqrt(
np.mean(np.square(clean_actions - noisy_actions)))
self.pi_distance = distance_in_action_space
if distance_in_action_space < \
self.config["exploration_ou_sigma"] * self.cur_noise_scale:
# multiplying the sampled OU noise by noise scale is
# equivalent to multiplying the sigma of OU by noise scale
self.parameter_noise_sigma_val *= 1.01
else:
self.parameter_noise_sigma_val /= 1.01
self.parameter_noise_sigma.load(
self.parameter_noise_sigma_val, session=self.sess)
return _postprocess_dqn(self, sample_batch)
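
# A minimal numpy sketch of the sigma adaptation performed above: compare the
# mean action-space distance between noisy and noise-free actions against a
# target distance (exploration_ou_sigma * cur_noise_scale) and nudge the
# parameter-noise sigma by 1% in the corresponding direction.
import numpy as np

def adapt_parameter_noise_sigma(sigma, clean_actions, noisy_actions,
                                target_distance):
    distance = np.sqrt(np.mean(np.square(clean_actions - noisy_actions)))
    return sigma * 1.01 if distance < target_distance else sigma / 1.01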
class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(
"Action space {} is not supported for DDPG.".format(action_space))
"Action space {} is not supported for DDPG.".format(
action_space))
if len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space has multiple dimensions "
@@ -42,65 +85,369 @@ def build_ddpg_model(policy, obs_space, action_space, config):
"Consider reshaping this into a single dimension, "
"using a Tuple action space, or the multi-agent API.")
if config["use_state_preprocessor"]:
default_model = None # catalog decides
num_outputs = 256 # arbitrary
config["model"]["no_final_linear"] = True
self.config = config
self.cur_noise_scale = 1.0
self.cur_pure_exploration_phase = False
self.dim_actions = action_space.shape[0]
self.low_action = action_space.low
self.high_action = action_space.high
# create global step for counting the number of update operations
self.global_step = tf.train.get_or_create_global_step()
# use separate optimizers for actor & critic
self._actor_optimizer = tf.train.AdamOptimizer(
learning_rate=self.config["actor_lr"])
self._critic_optimizer = tf.train.AdamOptimizer(
learning_rate=self.config["critic_lr"])
# Action inputs
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
self.noise_scale = tf.placeholder(tf.float32, (), name="noise_scale")
self.pure_exploration_phase = tf.placeholder(
tf.bool, (), name="pure_exploration_phase")
self.cur_observations = tf.placeholder(
tf.float32,
shape=(None, ) + observation_space.shape,
name="cur_obs")
with tf.variable_scope(POLICY_SCOPE) as scope:
policy_out, self.policy_model = self._build_policy_network(
self.cur_observations, observation_space, action_space)
self.policy_vars = scope_vars(scope.name)
# Noise vars for P network except for layer normalization vars
if self.config["parameter_noise"]:
self._build_parameter_noise([
var for var in self.policy_vars if "LayerNorm" not in var.name
])
# Action outputs
with tf.variable_scope(ACTION_SCOPE):
self.output_actions = self._add_exploration_noise(
policy_out, self.stochastic, self.noise_scale,
self.pure_exploration_phase, action_space)
if self.config["smooth_target_policy"]:
self.reset_noise_op = tf.no_op()
else:
default_model = NoopModel
num_outputs = int(np.product(obs_space.shape))
with tf.variable_scope(ACTION_SCOPE, reuse=True):
exploration_sample = tf.get_variable(name="ornstein_uhlenbeck")
self.reset_noise_op = tf.assign(exploration_sample,
self.dim_actions * [.0])
policy.model = ModelCatalog.get_model_v2(
obs_space,
# Replay inputs
self.obs_t = tf.placeholder(
tf.float32,
shape=(None, ) + observation_space.shape,
name="observation")
self.act_t = tf.placeholder(
tf.float32, shape=(None, ) + action_space.shape, name="action")
self.rew_t = tf.placeholder(tf.float32, [None], name="reward")
self.obs_tp1 = tf.placeholder(
tf.float32, shape=(None, ) + observation_space.shape)
self.done_mask = tf.placeholder(tf.float32, [None], name="done")
self.importance_weights = tf.placeholder(
tf.float32, [None], name="weight")
# policy network evaluation
with tf.variable_scope(POLICY_SCOPE, reuse=True) as scope:
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
self.policy_t, _ = self._build_policy_network(
self.obs_t, observation_space, action_space)
policy_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)
# target policy network evaluation
with tf.variable_scope(POLICY_TARGET_SCOPE) as scope:
policy_tp1, _ = self._build_policy_network(
self.obs_tp1, observation_space, action_space)
target_policy_vars = scope_vars(scope.name)
# Action outputs
with tf.variable_scope(ACTION_SCOPE, reuse=True):
if config["smooth_target_policy"]:
target_noise_clip = self.config["target_noise_clip"]
clipped_normal_sample = tf.clip_by_value(
tf.random_normal(
tf.shape(policy_tp1),
stddev=self.config["target_noise"]),
-target_noise_clip, target_noise_clip)
policy_tp1_smoothed = tf.clip_by_value(
policy_tp1 + clipped_normal_sample,
action_space.low * tf.ones_like(policy_tp1),
action_space.high * tf.ones_like(policy_tp1))
else:
# no smoothing, just use deterministic actions
policy_tp1_smoothed = policy_tp1
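
# A minimal numpy sketch of the target-policy smoothing above (from TD3):
# clipped Gaussian noise is added to the target policy's action before the
# target critic evaluates it, then the result is clipped to the action
# bounds.
import numpy as np

def smooth_target_action(action, low, high, target_noise, noise_clip):
    noise = np.clip(
        np.random.normal(scale=target_noise, size=np.shape(action)),
        -noise_clip, noise_clip)
    return np.clip(action + noise, low, high)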
# q network evaluation
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.variable_scope(Q_SCOPE) as scope:
# Q-values for given actions & observations in given current state
q_t, self.q_model = self._build_q_network(
self.obs_t, observation_space, action_space, self.act_t)
self.q_func_vars = scope_vars(scope.name)
self.stats = {
"mean_q": tf.reduce_mean(q_t),
"max_q": tf.reduce_max(q_t),
"min_q": tf.reduce_min(q_t),
}
with tf.variable_scope(Q_SCOPE, reuse=True):
# Q-values for current policy (no noise) in given current state
q_t_det_policy, _ = self._build_q_network(
self.obs_t, observation_space, action_space, self.policy_t)
if self.config["twin_q"]:
with tf.variable_scope(TWIN_Q_SCOPE) as scope:
twin_q_t, self.twin_q_model = self._build_q_network(
self.obs_t, observation_space, action_space, self.act_t)
self.twin_q_func_vars = scope_vars(scope.name)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
action_space, policy_tp1_smoothed)
target_q_func_vars = scope_vars(scope.name)
if self.config["twin_q"]:
with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope:
twin_q_tp1, _ = self._build_q_network(
self.obs_tp1, observation_space, action_space,
policy_tp1_smoothed)
twin_target_q_func_vars = scope_vars(scope.name)
if self.config["twin_q"]:
self.critic_loss, self.actor_loss, self.td_error \
= self._build_actor_critic_loss(
q_t, q_tp1, q_t_det_policy, twin_q_t=twin_q_t,
twin_q_tp1=twin_q_tp1)
else:
self.critic_loss, self.actor_loss, self.td_error \
= self._build_actor_critic_loss(
q_t, q_tp1, q_t_det_policy)
if config["l2_reg"] is not None:
for var in self.policy_vars:
if "bias" not in var.name:
self.actor_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
for var in self.q_func_vars:
if "bias" not in var.name:
self.critic_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
if self.config["twin_q"]:
for var in self.twin_q_func_vars:
if "bias" not in var.name:
self.critic_loss += (
config["l2_reg"] * tf.nn.l2_loss(var))
# update_target_fn will be called periodically to copy Q network to
# target Q network
self.tau_value = config.get("tau")
self.tau = tf.placeholder(tf.float32, (), name="tau")
update_target_expr = []
for var, var_target in zip(
sorted(self.q_func_vars, key=lambda v: v.name),
sorted(target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(
var_target.assign(self.tau * var +
(1.0 - self.tau) * var_target))
if self.config["twin_q"]:
for var, var_target in zip(
sorted(self.twin_q_func_vars, key=lambda v: v.name),
sorted(twin_target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(
var_target.assign(self.tau * var +
(1.0 - self.tau) * var_target))
for var, var_target in zip(
sorted(self.policy_vars, key=lambda v: v.name),
sorted(target_policy_vars, key=lambda v: v.name)):
update_target_expr.append(
var_target.assign(self.tau * var +
(1.0 - self.tau) * var_target))
self.update_target_expr = tf.group(*update_target_expr)
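
# A minimal numpy sketch of the update op built above: a soft (Polyak) update
# with mixing rate tau, where tau=1.0 reproduces the hard initial copy and a
# small tau (the "tau" config value) gives the slow tracking update used
# during training.
import numpy as np

def soft_update(online_vars, target_vars, tau):
    return [tau * w + (1.0 - tau) * w_target
            for w, w_target in zip(online_vars, target_vars)]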
self.sess = tf.get_default_session()
self.loss_inputs = [
(SampleBatch.CUR_OBS, self.obs_t),
(SampleBatch.ACTIONS, self.act_t),
(SampleBatch.REWARDS, self.rew_t),
(SampleBatch.NEXT_OBS, self.obs_tp1),
(SampleBatch.DONES, self.done_mask),
(PRIO_WEIGHTS, self.importance_weights),
]
input_dict = dict(self.loss_inputs)
if self.config["use_state_preprocessor"]:
# Model self-supervised losses
self.actor_loss = self.policy_model.custom_loss(
self.actor_loss, input_dict)
self.critic_loss = self.q_model.custom_loss(
self.critic_loss, input_dict)
if self.config["twin_q"]:
self.critic_loss = self.twin_q_model.custom_loss(
self.critic_loss, input_dict)
TFPolicy.__init__(
self,
observation_space,
action_space,
num_outputs,
config["model"],
framework="tf",
model_interface=DDPGModel,
default_model=default_model,
name="ddpg_model",
actor_hidden_activation=config["actor_hidden_activation"],
actor_hiddens=config["actor_hiddens"],
critic_hidden_activation=config["critic_hidden_activation"],
critic_hiddens=config["critic_hiddens"],
parameter_noise=config["parameter_noise"],
twin_q=config["twin_q"])
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=self.actor_loss + self.critic_loss,
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())
policy.target_model = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs,
config["model"],
framework="tf",
model_interface=DDPGModel,
default_model=default_model,
name="target_ddpg_model",
actor_hidden_activation=config["actor_hidden_activation"],
actor_hiddens=config["actor_hiddens"],
critic_hidden_activation=config["critic_hidden_activation"],
critic_hiddens=config["critic_hiddens"],
parameter_noise=config["parameter_noise"],
twin_q=config["twin_q"])
# Note that this encompasses both the policy and Q-value networks and
# their corresponding target networks
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
tf.group(q_t_det_policy, q_tp1), self.sess)
return policy.model
# Hard initial update
self.update_target(tau=1.0)
@override(TFPolicy)
def optimizer(self):
# we don't use this because we have two separate optimizers
return None
def postprocess_trajectory(policy,
sample_batch,
other_agent_batches=None,
episode=None):
if policy.config["parameter_noise"]:
policy.adjust_param_noise_sigma(sample_batch)
return _postprocess_dqn(policy, sample_batch)
@override(TFPolicy)
def build_apply_op(self, optimizer, grads_and_vars):
# for the policy gradient, update the policy net once for every
# `policy_delay` updates of the critic net
should_apply_actor_opt = tf.equal(
tf.mod(self.global_step, self.config["policy_delay"]), 0)
def make_apply_op():
return self._actor_optimizer.apply_gradients(
self._actor_grads_and_vars)
def build_action_output(policy, model, input_dict, obs_space, action_space,
config):
model_out, _ = model({
"obs": input_dict[SampleBatch.CUR_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
action_out = model.get_policy_output(model_out)
actor_op = tf.cond(
should_apply_actor_opt,
true_fn=make_apply_op,
false_fn=lambda: tf.no_op())
critic_op = self._critic_optimizer.apply_gradients(
self._critic_grads_and_vars)
# increment global step & apply ops
with tf.control_dependencies([tf.assign_add(self.global_step, 1)]):
return tf.group(actor_op, critic_op)
@override(TFPolicy)
def gradients(self, optimizer, loss):
if self.config["grad_norm_clipping"] is not None:
actor_grads_and_vars = minimize_and_clip(
self._actor_optimizer,
self.actor_loss,
var_list=self.policy_vars,
clip_val=self.config["grad_norm_clipping"])
critic_grads_and_vars = minimize_and_clip(
self._critic_optimizer,
self.critic_loss,
var_list=self.q_func_vars + self.twin_q_func_vars
if self.config["twin_q"] else self.q_func_vars,
clip_val=self.config["grad_norm_clipping"])
else:
actor_grads_and_vars = self._actor_optimizer.compute_gradients(
self.actor_loss, var_list=self.policy_vars)
if self.config["twin_q"]:
critic_vars = self.q_func_vars + self.twin_q_func_vars
else:
critic_vars = self.q_func_vars
critic_grads_and_vars = self._critic_optimizer.compute_gradients(
self.critic_loss, var_list=critic_vars)
# save these for later use in build_apply_op
self._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
if g is not None]
self._critic_grads_and_vars = [(g, v)
for (g, v) in critic_grads_and_vars
if g is not None]
grads_and_vars = self._actor_grads_and_vars \
+ self._critic_grads_and_vars
return grads_and_vars
@override(TFPolicy)
def extra_compute_action_feed_dict(self):
return {
# FIXME: what about turning off exploration? Isn't that a good
# idea?
self.stochastic: True,
self.noise_scale: self.cur_noise_scale,
self.pure_exploration_phase: self.cur_pure_exploration_phase,
}
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return {
"td_error": self.td_error,
LEARNER_STATS_KEY: self.stats,
}
@override(TFPolicy)
def get_weights(self):
return self.variables.get_weights()
@override(TFPolicy)
def set_weights(self, weights):
self.variables.set_weights(weights)
@override(Policy)
def get_state(self):
return [
TFPolicy.get_state(self), self.cur_noise_scale,
self.cur_pure_exploration_phase
]
@override(Policy)
def set_state(self, state):
TFPolicy.set_state(self, state[0])
self.set_epsilon(state[1])
self.set_pure_exploration_phase(state[2])
def _build_q_network(self, obs, obs_space, action_space, actions):
if self.config["use_state_preprocessor"]:
q_model = ModelCatalog.get_model({
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, action_space, 1, self.config["model"])
q_out = tf.concat([q_model.last_layer, actions], axis=1)
else:
q_model = None
q_out = tf.concat([obs, actions], axis=1)
activation = getattr(tf.nn, self.config["critic_hidden_activation"])
for hidden in self.config["critic_hiddens"]:
q_out = tf.layers.dense(q_out, units=hidden, activation=activation)
q_values = tf.layers.dense(q_out, units=1, activation=None)
return q_values, q_model
def _build_policy_network(self, obs, obs_space, action_space):
if self.config["use_state_preprocessor"]:
model = ModelCatalog.get_model({
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, action_space, 1, self.config["model"])
action_out = model.last_layer
else:
model = None
action_out = obs
activation = getattr(tf.nn, self.config["actor_hidden_activation"])
for hidden in self.config["actor_hiddens"]:
if self.config["parameter_noise"]:
import tensorflow.contrib.layers as layers
action_out = layers.fully_connected(
action_out,
num_outputs=hidden,
activation_fn=activation,
normalizer_fn=layers.layer_norm)
else:
action_out = tf.layers.dense(
action_out, units=hidden, activation=activation)
action_out = tf.layers.dense(
action_out, units=self.dim_actions, activation=None)
# Use a sigmoid to scale to [0, 1], but double the magnitude of the input
# to emulate the behaviour of the tanh activation used in the DDPG and TD3
# papers.
@@ -110,9 +457,14 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
# get same dims)
action_range = (action_space.high - action_space.low)[None]
low_action = action_space.low[None]
deterministic_actions = action_range * sigmoid_out + low_action
actions = action_range * sigmoid_out + low_action
noise_type = config["exploration_noise_type"]
return actions, model
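
# A minimal numpy sketch of the squashing above: since
# tanh(x) == 2 * sigmoid(2 * x) - 1, sigmoid(2 * x) is exactly the tanh
# squash rescaled from [-1, 1] to [0, 1]; the affine map then stretches it
# onto [low, high] per action dimension.
import numpy as np

def squash_to_range(action_out, low, high):
    sigmoid_out = 1.0 / (1.0 + np.exp(-2.0 * action_out))
    return (high - low) * sigmoid_out + low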
def _add_exploration_noise(self, deterministic_actions,
should_be_stochastic, noise_scale,
enable_pure_exploration, action_space):
noise_type = self.config["exploration_noise_type"]
action_low = action_space.low
action_high = action_space.high
action_range = action_space.high - action_low
@@ -122,9 +474,9 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
# shape of deterministic_actions is [None, dim_action]
if noise_type == "gaussian":
# add IID Gaussian noise for exploration, TD3-style
normal_sample = policy.noise_scale * tf.random_normal(
normal_sample = noise_scale * tf.random_normal(
tf.shape(deterministic_actions),
stddev=config["exploration_gaussian_sigma"])
stddev=self.config["exploration_gaussian_sigma"])
stochastic_actions = tf.clip_by_value(
deterministic_actions + normal_sample,
action_low * tf.ones_like(deterministic_actions),
@@ -139,12 +491,13 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
trainable=False)
normal_sample = tf.random_normal(
shape=[action_low.size], mean=0.0, stddev=1.0)
ou_new = config["exploration_ou_theta"] \
ou_new = self.config["exploration_ou_theta"] \
* -exploration_sample \
+ config["exploration_ou_sigma"] * normal_sample
exploration_value = tf.assign_add(exploration_sample, ou_new)
base_scale = config["exploration_ou_noise_scale"]
noise = policy.noise_scale * base_scale \
+ self.config["exploration_ou_sigma"] * normal_sample
exploration_value = tf.assign_add(exploration_sample,
ou_new)
base_scale = self.config["exploration_ou_noise_scale"]
noise = noise_scale * base_scale \
* exploration_value * action_range
stochastic_actions = tf.clip_by_value(
deterministic_actions + noise,
@@ -172,233 +525,126 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
# noise_scale is how a worker signals no noise should be used
# (this is ugly and should be fixed by adding an "eval_mode"
# config flag or something)
tf.logical_and(policy.pure_exploration_phase,
policy.noise_scale > 0),
tf.logical_and(enable_pure_exploration, noise_scale > 0),
true_fn=make_uniform_random_actions,
false_fn=make_noisy_actions)
return stochastic_actions
enable_stochastic = tf.logical_and(policy.stochastic,
not config["parameter_noise"])
enable_stochastic = tf.logical_and(should_be_stochastic,
not self.config["parameter_noise"])
actions = tf.cond(enable_stochastic, compute_stochastic_actions,
lambda: deterministic_actions)
policy.output_actions = actions
return actions, None
return actions
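
# A minimal numpy sketch of the Ornstein-Uhlenbeck exploration noise above:
# a stateful process x += theta * -x + sigma * N(0, 1) whose samples are
# temporally correlated, later scaled by the noise scale and action range.
import numpy as np

def ou_step(state, theta=0.15, sigma=0.2):
    return state + theta * -state + sigma * np.random.normal(
        size=state.shape)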
def actor_critic_loss(policy, model, _, train_batch):
model_out_t, _ = model({
"obs": train_batch[SampleBatch.CUR_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
model_out_tp1, _ = model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
target_model_out_tp1, _ = policy.target_model({
"obs": train_batch[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
policy_t = model.get_policy_output(model_out_t)
policy_tp1 = model.get_policy_output(model_out_tp1)
if policy.config["smooth_target_policy"]:
target_noise_clip = policy.config["target_noise_clip"]
clipped_normal_sample = tf.clip_by_value(
tf.random_normal(
tf.shape(policy_tp1), stddev=policy.config["target_noise"]),
-target_noise_clip, target_noise_clip)
policy_tp1_smoothed = tf.clip_by_value(
policy_tp1 + clipped_normal_sample,
policy.action_space.low * tf.ones_like(policy_tp1),
policy.action_space.high * tf.ones_like(policy_tp1))
else:
policy_tp1_smoothed = policy_tp1
# q network evaluation
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
if policy.config["twin_q"]:
twin_q_t = model.get_twin_q_values(model_out_t,
train_batch[SampleBatch.ACTIONS])
# Q-values for current policy (no noise) in given current state
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
# target q network evaluation
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
policy_tp1_smoothed)
if policy.config["twin_q"]:
twin_q_tp1 = policy.target_model.get_twin_q_values(
target_model_out_tp1, policy_tp1_smoothed)
def _build_actor_critic_loss(self,
q_t,
q_tp1,
q_t_det_policy,
twin_q_t=None,
twin_q_tp1=None):
twin_q = self.config["twin_q"]
gamma = self.config["gamma"]
n_step = self.config["n_step"]
use_huber = self.config["use_huber"]
huber_threshold = self.config["huber_threshold"]
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
if policy.config["twin_q"]:
if twin_q:
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
q_tp1_best_masked = (
1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * q_tp1_best
q_tp1_best_masked = (1.0 - self.done_mask) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = tf.stop_gradient(
train_batch[SampleBatch.REWARDS] +
policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)
self.rew_t + gamma**n_step * q_tp1_best_masked)
# compute the error (potentially clipped)
if policy.config["twin_q"]:
if twin_q:
td_error = q_t_selected - q_t_selected_target
twin_td_error = twin_q_t_selected - q_t_selected_target
td_error = td_error + twin_td_error
if policy.config["use_huber"]:
errors = huber_loss(td_error, policy.config["huber_threshold"]) \
+ huber_loss(twin_td_error, policy.config["huber_threshold"])
if use_huber:
errors = huber_loss(td_error, huber_threshold) \
+ huber_loss(twin_td_error, huber_threshold)
else:
errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(twin_td_error)
errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(
twin_td_error)
else:
td_error = q_t_selected - q_t_selected_target
if policy.config["use_huber"]:
errors = huber_loss(td_error, policy.config["huber_threshold"])
if use_huber:
errors = huber_loss(td_error, huber_threshold)
else:
errors = 0.5 * tf.square(td_error)
critic_loss = model.custom_loss(
tf.reduce_mean(
tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors),
train_batch)
critic_loss = tf.reduce_mean(self.importance_weights * errors)
actor_loss = -tf.reduce_mean(q_t_det_policy)
return critic_loss, actor_loss, td_error
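
# A minimal numpy sketch of the Bellman target used above: with twin_q the
# two target critics are pessimistically combined via a minimum, terminal
# transitions are masked out, and the n-step return horizon enters through
# the discount exponent.
import numpy as np

def td_target(rewards, dones, q_tp1, twin_q_tp1=None, gamma=0.99, n_step=1):
    if twin_q_tp1 is not None:
        q_tp1 = np.minimum(q_tp1, twin_q_tp1)
    return rewards + gamma**n_step * (1.0 - dones) * q_tp1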
if policy.config["l2_reg"] is not None:
for var in model.policy_variables():
if "bias" not in var.name:
actor_loss += policy.config["l2_reg"] * tf.nn.l2_loss(var)
for var in model.q_variables():
if "bias" not in var.name:
critic_loss += policy.config["l2_reg"] * tf.nn.l2_loss(var)
# save for stats function
policy.q_t = q_t
policy.td_error = td_error
policy.actor_loss = actor_loss
policy.critic_loss = critic_loss
# in a custom apply op we handle the losses separately, but return them
# combined in one loss for now
return actor_loss + critic_loss
def gradients(policy, optimizer, loss):
if policy.config["grad_norm_clipping"] is not None:
actor_grads_and_vars = minimize_and_clip(
optimizer,
policy.actor_loss,
var_list=policy.model.policy_variables(),
clip_val=policy.config["grad_norm_clipping"])
critic_grads_and_vars = minimize_and_clip(
optimizer,
policy.critic_loss,
var_list=policy.model.q_variables(),
clip_val=policy.config["grad_norm_clipping"])
else:
actor_grads_and_vars = optimizer.compute_gradients(
policy.actor_loss, var_list=policy.model.policy_variables())
critic_grads_and_vars = optimizer.compute_gradients(
policy.critic_loss, var_list=policy.model.q_variables())
# save these for later use in build_apply_op
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
if g is not None]
policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
if g is not None]
grads_and_vars = (
policy._actor_grads_and_vars + policy._critic_grads_and_vars)
return grads_and_vars
def apply_gradients(policy, optimizer, grads_and_vars):
# for the policy gradient, update the policy net once for every
# `policy_delay` updates of the critic net
should_apply_actor_opt = tf.equal(
tf.mod(policy.global_step, policy.config["policy_delay"]), 0)
def make_apply_op():
return policy._actor_optimizer.apply_gradients(
policy._actor_grads_and_vars)
actor_op = tf.cond(
should_apply_actor_opt,
true_fn=make_apply_op,
false_fn=lambda: tf.no_op())
critic_op = policy._critic_optimizer.apply_gradients(
policy._critic_grads_and_vars)
# increment global step & apply ops
with tf.control_dependencies([tf.assign_add(policy.global_step, 1)]):
return tf.group(actor_op, critic_op)
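
# A minimal sketch of the delayed policy update above (from TD3): the critic
# op runs every step, while the actor op is a no-op except on steps where
# global_step % policy_delay == 0.

def should_update_actor(global_step, policy_delay):
    return global_step % policy_delay == 0

assert [should_update_actor(t, 2) for t in range(4)] == [
    True, False, True, False
]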
def stats(policy, train_batch):
return {
"td_error": tf.reduce_mean(policy.td_error),
"actor_loss": tf.reduce_mean(policy.actor_loss),
"critic_loss": tf.reduce_mean(policy.critic_loss),
"mean_q": tf.reduce_mean(policy.q_t),
"max_q": tf.reduce_max(policy.q_t),
"min_q": tf.reduce_min(policy.q_t),
}
class ExplorationStateMixin(object):
def __init__(self, obs_space, action_space, config):
self.cur_noise_scale = 1.0
self.cur_pure_exploration_phase = False
self.stochastic = tf.get_variable(
initializer=tf.constant_initializer(True),
name="stochastic",
shape=(),
trainable=False,
dtype=tf.bool)
self.noise_scale = tf.get_variable(
initializer=tf.constant_initializer(self.cur_noise_scale),
name="noise_scale",
def _build_parameter_noise(self, pnet_params):
self.parameter_noise_sigma_val = self.config["exploration_ou_sigma"]
self.parameter_noise_sigma = tf.get_variable(
initializer=tf.constant_initializer(
self.parameter_noise_sigma_val),
name="parameter_noise_sigma",
shape=(),
trainable=False,
dtype=tf.float32)
self.pure_exploration_phase = tf.get_variable(
initializer=tf.constant_initializer(
self.cur_pure_exploration_phase),
name="pure_exploration_phase",
shape=(),
trainable=False,
dtype=tf.bool)
self.parameter_noise = list()
# No need to add any noise on LayerNorm parameters
for var in pnet_params:
noise_var = tf.get_variable(
name=var.name.split(":")[0] + "_noise",
shape=var.shape,
initializer=tf.constant_initializer(.0),
trainable=False)
self.parameter_noise.append(noise_var)
remove_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
remove_noise_ops.append(tf.assign_add(var, -var_noise))
self.remove_noise_op = tf.group(*tuple(remove_noise_ops))
generate_noise_ops = list()
for var_noise in self.parameter_noise:
generate_noise_ops.append(
tf.assign(
var_noise,
tf.random_normal(
shape=var_noise.shape,
stddev=self.parameter_noise_sigma)))
with tf.control_dependencies(generate_noise_ops):
add_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
add_noise_ops.append(tf.assign_add(var, var_noise))
self.add_noise_op = tf.group(*tuple(add_noise_ops))
self.pi_distance = None
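
# A minimal numpy sketch of the perturb/restore cycle wired up above: one
# noise tensor is kept per (non-LayerNorm) policy weight, added in place
# before sampling actions and subtracted again to recover the clean policy.
import numpy as np

def add_parameter_noise(weights, sigma):
    noise = [np.random.normal(scale=sigma, size=w.shape) for w in weights]
    for w, n in zip(weights, noise):
        w += n
    return noise  # retained so the exact same perturbation can be removed

def remove_parameter_noise(weights, noise):
    for w, n in zip(weights, noise):
        w -= n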
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err = self.sess.run(
self.td_error,
feed_dict={
self.obs_t: [np.array(ob) for ob in obs_t],
self.act_t: act_t,
self.rew_t: rew_t,
self.obs_tp1: [np.array(ob) for ob in obs_tp1],
self.done_mask: done_mask,
self.importance_weights: importance_weights
})
return td_err
def reset_noise(self, sess):
sess.run(self.reset_noise_op)
def add_parameter_noise(self):
if self.config["parameter_noise"]:
self.get_session().run(self.model.add_noise_op)
self.sess.run(self.add_noise_op)
def adjust_param_noise_sigma(self, sample_batch):
assert not tf.executing_eagerly(), "eager not supported with p noise"
# adjust the sigma of parameter space noise
states, noisy_actions = [
list(x) for x in sample_batch.columns(
[SampleBatch.CUR_OBS, SampleBatch.ACTIONS])
]
self.get_session().run(self.model.remove_noise_op)
clean_actions = self.get_session().run(
self.output_actions,
feed_dict={
self.get_placeholder(SampleBatch.CUR_OBS): states,
self.stochastic: False,
self.noise_scale: .0,
self.pure_exploration_phase: False,
})
distance_in_action_space = np.sqrt(
np.mean(np.square(clean_actions - noisy_actions)))
self.model.update_action_noise(
self.get_session(), distance_in_action_space,
self.config["exploration_ou_sigma"], self.cur_noise_scale)
# support both hard and soft sync
def update_target(self, tau=None):
tau = tau or self.tau_value
return self.sess.run(
self.update_target_expr, feed_dict={self.tau: tau})
def set_epsilon(self, epsilon):
# set_epsilon is called by optimizer to anneal exploration as
@@ -406,117 +652,6 @@ class ExplorationStateMixin(object):
# is a carry-over from DQN, which uses epsilon-greedy exploration
# rather than adding action noise to the output of a policy network.
self.cur_noise_scale = epsilon
self.noise_scale.load(self.cur_noise_scale, self.get_session())
def set_pure_exploration_phase(self, pure_exploration_phase):
self.cur_pure_exploration_phase = pure_exploration_phase
self.pure_exploration_phase.load(self.cur_pure_exploration_phase,
self.get_session())
@override(Policy)
def get_state(self):
return [
TFPolicy.get_state(self), self.cur_noise_scale,
self.cur_pure_exploration_phase
]
@override(Policy)
def set_state(self, state):
TFPolicy.set_state(self, state[0])
self.set_epsilon(state[1])
self.set_pure_exploration_phase(state[2])
class TargetNetworkMixin(object):
def __init__(self, config):
@make_tf_callable(self.get_session())
def update_target_fn(tau):
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
update_target_expr = []
model_vars = self.model.trainable_variables()
target_model_vars = self.target_model.trainable_variables()
assert len(model_vars) == len(target_model_vars), \
(model_vars, target_model_vars)
for var, var_target in zip(model_vars, target_model_vars):
update_target_expr.append(
var_target.assign(tau * var + (1.0 - tau) * var_target))
logger.debug("Update target op {}".format(var_target))
return tf.group(*update_target_expr)
# Hard initial update
self._do_update = update_target_fn
self.update_target(tau=1.0)
# support both hard and soft sync
def update_target(self, tau=None):
self._do_update(np.float32(tau or self.config.get("tau")))
class ActorCriticOptimizerMixin(object):
def __init__(self, config):
# create global step for counting the number of update operations
self.global_step = tf.train.get_or_create_global_step()
# use separate optimizers for actor & critic
self._actor_optimizer = tf.train.AdamOptimizer(
learning_rate=config["actor_lr"])
self._critic_optimizer = tf.train.AdamOptimizer(
learning_rate=config["critic_lr"])
class ComputeTDErrorMixin(object):
def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
if not self.loss_initialized():
return tf.zeros_like(rew_t)
# Do forward pass on loss to update td error attribute
actor_critic_loss(
self, self.model, None, {
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
})
return self.td_error
self.compute_td_error = compute_td_error
def setup_early_mixins(policy, obs_space, action_space, config):
ExplorationStateMixin.__init__(policy, obs_space, action_space, config)
ActorCriticOptimizerMixin.__init__(policy, config)
def setup_mid_mixins(policy, obs_space, action_space, config):
ComputeTDErrorMixin.__init__(policy)
def setup_late_mixins(policy, obs_space, action_space, config):
TargetNetworkMixin.__init__(policy, config)
DDPGTFPolicy = build_tf_policy(
name="DDPGTFPolicy",
get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
make_model=build_ddpg_model,
postprocess_fn=postprocess_trajectory,
action_sampler_fn=build_action_output,
loss_fn=actor_critic_loss,
stats_fn=stats,
gradients_fn=gradients,
apply_gradients_fn=apply_gradients,
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
mixins=[
TargetNetworkMixin, ExplorationStateMixin, ActorCriticOptimizerMixin,
ComputeTDErrorMixin
],
before_init=setup_early_mixins,
before_loss_init=setup_mid_mixins,
after_init=setup_late_mixins,
obs_include_prev_action_reward=False)
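
# A rough sketch of the lifecycle being removed here: build_tf_policy
# assembles a policy class from the plain functions above, with the hooks
# firing in roughly this order:
#   before_init -> make_model -> action_sampler_fn -> before_loss_init
#   -> loss_fn / stats_fn / gradients_fn / apply_gradients_fn -> after_init
# The revert replaces this declarative assembly with the hand-written
# DDPGTFPolicy.__init__ restored above, which performs the same steps inline.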

View file

@@ -186,8 +186,9 @@ def check_config_and_setup_param_noise(config):
             # between noisy policy and original policy
             policies = info["policy"]
             episode = info["episode"]
-            episode.custom_metrics["policy_distance"] = policies[
-                DEFAULT_POLICY_ID].model.pi_distance
+            model = policies[DEFAULT_POLICY_ID].model
+            if hasattr(model, "pi_distance"):
+                episode.custom_metrics["policy_distance"] = model.pi_distance
             if end_callback:
                 end_callback(info)

View file

@@ -10,15 +10,13 @@ import ray
 import ray.experimental.tf_utils
 from ray.rllib.agents.sac.sac_model import SACModel
 from ray.rllib.agents.ddpg.noop_model import NoopModel
-from ray.rllib.agents.ddpg.ddpg_policy import ComputeTDErrorMixin, \
-    TargetNetworkMixin
 from ray.rllib.agents.dqn.dqn_policy import _postprocess_dqn, PRIO_WEIGHTS
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils import try_import_tf, try_import_tfp
-from ray.rllib.utils.tf_ops import minimize_and_clip
+from ray.rllib.utils.tf_ops import minimize_and_clip, make_tf_callable
 
 tf = try_import_tf()
 tfp = try_import_tfp()
@@ -287,6 +285,55 @@ class ActorCriticOptimizerMixin(object):
learning_rate=config["optimization"]["entropy_learning_rate"])
class ComputeTDErrorMixin(object):
def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
if not self.loss_initialized():
return tf.zeros_like(rew_t)
# Do forward pass on loss to update td error attribute
actor_critic_loss(
self, self.model, None, {
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
})
return self.td_error
self.compute_td_error = compute_td_error
class TargetNetworkMixin(object):
def __init__(self, config):
@make_tf_callable(self.get_session())
def update_target_fn(tau):
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
update_target_expr = []
model_vars = self.model.trainable_variables()
target_model_vars = self.target_model.trainable_variables()
assert len(model_vars) == len(target_model_vars), \
(model_vars, target_model_vars)
for var, var_target in zip(model_vars, target_model_vars):
update_target_expr.append(
var_target.assign(tau * var + (1.0 - tau) * var_target))
logger.debug("Update target op {}".format(var_target))
return tf.group(*update_target_expr)
# Hard initial update
self._do_update = update_target_fn
self.update_target(tau=1.0)
# support both hard and soft sync
def update_target(self, tau=None):
self._do_update(np.float32(tau or self.config.get("tau")))
def setup_early_mixins(policy, obs_space, action_space, config):
ExplorationStateMixin.__init__(policy, obs_space, action_space, config)
ActorCriticOptimizerMixin.__init__(policy, config)

View file

@@ -56,30 +56,6 @@ class TestEagerSupport(unittest.TestCase):
"timesteps_per_iteration": 100
})
def testDDPG(self):
check_support("DDPG", {
"num_workers": 0,
"learning_starts": 0,
"timesteps_per_iteration": 10
})
def testTD3(self):
check_support("TD3", {
"num_workers": 0,
"learning_starts": 0,
"timesteps_per_iteration": 10
})
def testAPEX_DDPG(self):
check_support(
"APEX_DDPG", {
"num_workers": 2,
"learning_starts": 0,
"num_gpus": 0,
"min_iter_time_s": 1,
"timesteps_per_iteration": 100
})
def testSAC(self):
check_support("SAC", {
"num_workers": 0,