[rllib] Contribute DDPG to RLlib (#1877)

*  ongoing ddpg

*  ongoing ddpg converged

*  gpu machine changes

*  tuned

*  tuned ddpg specification

*  ddpg

*  add the missing optimizer argument clip_rewards to the default DQN configuration

*  ddpg supports vision env (atari) now

*  revised according to code review comments

*  added regression test case

*  removed irrelevant files

*  validate ddpg on mountain_car_continuous

*  revert unnecessary minor changes

*  revised according to Eric's comments

*  added the requested tests

*  revised accordingly

*  revised accordingly and re-validated

*  formatted by yapf

*  fix lint errors

*  formatted by yapf

*  fix lint errors

*  formatted by yapf

*  fix lint error
Jones Wong 2018-04-19 22:36:29 -07:00 committed by Eric Liang
parent aa07f1ce4e
commit c9a7744e52
16 changed files with 1013 additions and 6 deletions

@@ -9,7 +9,8 @@ from ray.tune.registry import register_trainable
def _register_all():
    for key in ["PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG",
-               "__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
+               "DDPG2", "APEX_DDPG2", "__fake", "__sigmoid_fake_data",
+               "__parameter_tuning"]:
        from ray.rllib.agent import get_agent_class
        register_trainable(key, get_agent_class(key))

@@ -231,7 +231,13 @@ class _ParameterTuningAgent(_MockAgent):
def get_agent_class(alg):
    """Returns the class of a known agent given its name."""
-    if alg == "PPO":
+    if alg == "DDPG2":
+        from ray.rllib import ddpg2
+        return ddpg2.DDPG2Agent
+    elif alg == "APEX_DDPG2":
+        from ray.rllib import ddpg2
+        return ddpg2.ApexDDPG2Agent
+    elif alg == "PPO":
        from ray.rllib import ppo
        return ppo.PPOAgent
    elif alg == "ES":

@@ -0,0 +1 @@
Code in this package follows the style of dqn.

@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.ddpg2.apex import ApexDDPG2Agent
from ray.rllib.ddpg2.ddpg import DDPG2Agent, DEFAULT_CONFIG
__all__ = ["DDPG2Agent", "ApexDDPG2Agent", "DEFAULT_CONFIG"]

@@ -0,0 +1,47 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.ddpg2.ddpg import DDPG2Agent, DEFAULT_CONFIG as DDPG_CONFIG
APEX_DDPG_DEFAULT_CONFIG = dict(DDPG_CONFIG,
**dict(
optimizer_class="ApexOptimizer",
optimizer_config=dict(
DDPG_CONFIG["optimizer_config"],
**dict(
max_weight_sync_delay=400,
num_replay_buffer_shards=4,
debug=False,
)),
n_step=3,
num_workers=32,
buffer_size=2000000,
learning_starts=50000,
train_batch_size=512,
sample_batch_size=50,
max_weight_sync_delay=400,
target_network_update_freq=500000,
timesteps_per_iteration=25000,
per_worker_exploration=True,
worker_side_prioritization=True,
))
class ApexDDPG2Agent(DDPG2Agent):
"""DDPG variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_agent_name = "APEX_DDPG"
_default_config = APEX_DDPG_DEFAULT_CONFIG
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.update_target()
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
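The Ape-X variant above is just the base DDPG defaults with the optimizer and sampling scale overridden via a dict merge. A small sketch (assuming the two modules added in this diff are importable) of what that merge produces:

from ray.rllib.ddpg2.apex import APEX_DDPG_DEFAULT_CONFIG
from ray.rllib.ddpg2.ddpg import DEFAULT_CONFIG as DDPG_CONFIG

assert DDPG_CONFIG["num_workers"] == 0                   # base DDPG samples locally
assert APEX_DDPG_DEFAULT_CONFIG["num_workers"] == 32     # Ape-X adds 32 sampling workers
assert APEX_DDPG_DEFAULT_CONFIG["optimizer_class"] == "ApexOptimizer"
# The nested optimizer_config is merged the same way, so the Ape-X specific
# keys sit alongside whatever the base config provided:
assert APEX_DDPG_DEFAULT_CONFIG["optimizer_config"]["num_replay_buffer_shards"] == 4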

@@ -0,0 +1,268 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import numpy as np
import tensorflow as tf
import ray
from ray.rllib import optimizers
from ray.rllib.ddpg2.ddpg_evaluator import DDPGEvaluator
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult
OPTIMIZER_SHARED_CONFIGS = [
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
"train_batch_size", "learning_starts", "clip_rewards"
]
DEFAULT_CONFIG = dict(
# === Model ===
# Hidden layer sizes of the policy networks
actor_hiddens=[64, 64],
# Hidden layer sizes of the critic (Q) network
critic_hiddens=[64, 64],
# N-step Q learning
n_step=1,
# Config options to pass to the model constructor
model={},
# Discount factor for the MDP
gamma=0.99,
# Arguments to pass to the env creator
env_config={},
# === Exploration ===
# Max num timesteps for annealing schedules. Exploration noise is annealed
# from noise_scale down to noise_scale * exploration_final_eps over
# exploration_fraction of this many timesteps
schedule_max_timesteps=100000,
# Number of env steps to optimize for before returning
timesteps_per_iteration=1000,
# Fraction of entire training period over which the exploration rate is
# annealed
exploration_fraction=0.1,
# Final value of random action probability
exploration_final_eps=0.02,
# Scale of the Ornstein-Uhlenbeck (OU) exploration noise
noise_scale=0.1,
# Theta parameter of the OU process
exploration_theta=0.15,
# Sigma parameter of the OU process
exploration_sigma=0.2,
# Update the target network every `target_network_update_freq` steps.
target_network_update_freq=0,
# Update the target by \tau * policy + (1-\tau) * target_policy
tau=0.002,
# Whether to start with random actions instead of noops.
random_starts=True,
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
buffer_size=50000,
# If True prioritized replay buffer will be used.
prioritized_replay=True,
# Alpha parameter for prioritized replay buffer.
prioritized_replay_alpha=0.6,
# Beta parameter for sampling from prioritized replay buffer.
prioritized_replay_beta=0.4,
# Epsilon to add to the TD errors when updating priorities.
prioritized_replay_eps=1e-6,
# Whether to clip rewards to [-1, 1] prior to adding to the replay buffer.
clip_rewards=True,
# === Optimization ===
# Learning rate for adam optimizer
actor_lr=1e-4,
critic_lr=1e-3,
# If True, use Huber loss instead of squared loss for the critic network
# (conventionally, gradients need not be clipped when using a Huber loss)
use_huber=False,
# Threshold of a huber loss
huber_threshold=1.0,
# Weights for L2 regularization
l2_reg=1e-6,
# If not None, clip gradients during optimization at this value
grad_norm_clipping=None,
# How many steps of the model to sample before learning starts.
learning_starts=1500,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
sample_batch_size=1,
# Size of a batch sampled from the replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
train_batch_size=256,
# Smooth the current average reward over this many previous episodes.
smoothing_num_episodes=100,
# === Tensorflow ===
# Arguments to pass to tensorflow
tf_session_args={
"device_count": {
"CPU": 2
},
"log_device_placement": False,
"allow_soft_placement": True,
"gpu_options": {
"allow_growth": True
},
"inter_op_parallelism_threads": 1,
"intra_op_parallelism_threads": 1,
},
# === Parallelism ===
# Number of workers used to collect samples. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
num_workers=0,
# Whether to allocate GPUs for workers (if > 0).
num_gpus_per_worker=0,
# Optimizer class to use.
optimizer_class="LocalSyncReplayOptimizer",
# Config to pass to the optimizer.
optimizer_config=dict(),
# Whether to use a distribution of epsilons across workers for exploration.
per_worker_exploration=False,
# Whether to compute priorities on workers.
worker_side_prioritization=False)
class DDPG2Agent(Agent):
_agent_name = "DDPG2"
_allow_unknown_subkeys = [
"model", "optimizer", "tf_session_args", "env_config"
]
_default_config = DEFAULT_CONFIG
def _init(self):
self.local_evaluator = DDPGEvaluator(self.registry, self.env_creator,
self.config, self.logdir, 0)
remote_cls = ray.remote(
num_cpus=1,
num_gpus=self.config["num_gpus_per_worker"])(DDPGEvaluator)
self.remote_evaluators = [
remote_cls.remote(self.registry, self.env_creator, self.config,
self.logdir, i)
for i in range(self.config["num_workers"])
]
for k in OPTIMIZER_SHARED_CONFIGS:
if k not in self.config["optimizer_config"]:
self.config["optimizer_config"][k] = self.config[k]
self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
self.config["optimizer_config"], self.local_evaluator,
self.remote_evaluators)
self.saver = tf.train.Saver(max_to_keep=None)
self.last_target_update_ts = 0
self.num_target_updates = 0
@property
def global_timestep(self):
return self.optimizer.num_steps_sampled
def update_target_if_needed(self):
if self.global_timestep - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.update_target()
self.last_target_update_ts = self.global_timestep
self.num_target_updates += 1
def _train(self):
start_timestep = self.global_timestep
while (self.global_timestep - start_timestep <
self.config["timesteps_per_iteration"]):
self.optimizer.step()
self.update_target_if_needed()
self.local_evaluator.set_global_timestep(self.global_timestep)
for e in self.remote_evaluators:
e.set_global_timestep.remote(self.global_timestep)
return self._train_stats(start_timestep)
def _train_stats(self, start_timestep):
if self.remote_evaluators:
stats = ray.get([e.stats.remote() for e in self.remote_evaluators])
else:
stats = self.local_evaluator.stats()
if not isinstance(stats, list):
stats = [stats]
mean_100ep_reward = 0.0
mean_100ep_length = 0.0
num_episodes = 0
explorations = []
if self.config["per_worker_exploration"]:
# Return stats from workers with the lowest 20% of exploration
test_stats = stats[-int(max(1, len(stats) * 0.2)):]
else:
test_stats = stats
for s in test_stats:
mean_100ep_reward += s["mean_100ep_reward"] / len(test_stats)
mean_100ep_length += s["mean_100ep_length"] / len(test_stats)
for s in stats:
num_episodes += s["num_episodes"]
explorations.append(s["exploration"])
opt_stats = self.optimizer.stats()
result = TrainingResult(
episode_reward_mean=mean_100ep_reward,
episode_len_mean=mean_100ep_length,
episodes_total=num_episodes,
timesteps_this_iter=self.global_timestep - start_timestep,
info=dict({
"min_exploration": min(explorations),
"max_exploration": max(explorations),
"num_target_updates": self.num_target_updates,
}, **opt_stats))
return result
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote(ev._ray_actor_id.id())
def _save(self, checkpoint_dir):
checkpoint_path = self.saver.save(
self.local_evaluator.sess,
os.path.join(checkpoint_dir, "checkpoint"),
global_step=self.iteration)
extra_data = [
self.local_evaluator.save(),
ray.get([e.save.remote() for e in self.remote_evaluators]),
self.optimizer.save(), self.num_target_updates,
self.last_target_update_ts
]
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
self.saver.restore(self.local_evaluator.sess, checkpoint_path)
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.local_evaluator.restore(extra_data[0])
ray.get([
e.restore.remote(d)
for (d, e) in zip(extra_data[1], self.remote_evaluators)
])
self.optimizer.restore(extra_data[2])
self.num_target_updates = extra_data[3]
self.last_target_update_ts = extra_data[4]
def compute_action(self, observation):
return self.local_evaluator.ddpg_graph.act(self.local_evaluator.sess,
np.array(observation)[None],
0.0)[0]
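The _save/_restore pair above writes a TensorFlow checkpoint plus a pickled ".extra_data" blob holding evaluator and optimizer state. A sketch of the round trip, mirroring test_checkpoint_restore.py further down (Pendulum-v0 observations are 3-dimensional):

import numpy as np
import ray
from ray.rllib import ddpg2

ray.init()
alg1 = ddpg2.DDPG2Agent(config={"noise_scale": 0.0}, env="Pendulum-v0")
alg2 = ddpg2.DDPG2Agent(config={"noise_scale": 0.0}, env="Pendulum-v0")
alg1.train()
checkpoint_path = alg1.save()   # TF checkpoint + checkpoint_path + ".extra_data"
alg2.restore(checkpoint_path)   # restores weights, evaluator state, optimizer state
obs = np.random.uniform(size=3)
# compute_action passes eps=0.0, so the policy is deterministic and both agree.
assert np.allclose(alg1.compute_action(obs), alg2.compute_action(obs))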

@@ -0,0 +1,186 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Box
import numpy as np
import tensorflow as tf
import ray
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.ddpg2 import models
from ray.rllib.dqn.common.schedules import ConstantSchedule, LinearSchedule
from ray.rllib.optimizers import SampleBatch, PolicyEvaluator
from ray.rllib.utils.compression import pack
from ray.rllib.dqn.dqn_evaluator import adjust_nstep
from ray.rllib.dqn.common.wrappers import wrap_dqn
class DDPGEvaluator(PolicyEvaluator):
"""The base DDPG Evaluator."""
def __init__(self, registry, env_creator, config, logdir, worker_index):
env = env_creator(config["env_config"])
env = wrap_dqn(registry, env, config["model"], config["random_starts"])
self.env = env
self.config = config
# env.action_space must be of Box type; e.g., for Pendulum-v0,
# action_space.low is [-2.0] and action_space.high is [2.0], and an
# action is taken by calling, e.g., env.step([3.5])
if not isinstance(env.action_space, Box):
raise UnsupportedSpaceException(
"Action space {} is not supported for DDPG.".format(
env.action_space))
tf_config = tf.ConfigProto(**config["tf_session_args"])
self.sess = tf.Session(config=tf_config)
self.ddpg_graph = models.DDPGGraph(registry, env, config, logdir)
# Use either a different `eps` per worker, or a linear schedule.
if config["per_worker_exploration"]:
assert config["num_workers"] > 1, "This requires multiple workers"
self.exploration = ConstantSchedule(
config["noise_scale"] * 0.4 **
(1 + worker_index / float(config["num_workers"] - 1) * 7))
else:
self.exploration = LinearSchedule(
schedule_timesteps=int(config["exploration_fraction"] *
config["schedule_max_timesteps"]),
initial_p=config["noise_scale"] * 1.0,
final_p=config["noise_scale"] *
config["exploration_final_eps"])
# Initialize the parameters and copy them to the target network.
self.sess.run(tf.global_variables_initializer())
# do a hard update (tau=1.0) instead of a soft update
self.ddpg_graph.update_target(self.sess, 1.0)
self.global_timestep = 0
self.local_timestep = 0
# Note that this encompasses both the policy and Q-value networks and
# their corresponding target networks
self.variables = ray.experimental.TensorFlowVariables(
tf.group(self.ddpg_graph.q_tp0, self.ddpg_graph.q_tp1), self.sess)
self.episode_rewards = [0.0]
self.episode_lengths = [0.0]
self.saved_mean_reward = None
self.obs = self.env.reset()
def set_global_timestep(self, global_timestep):
self.global_timestep = global_timestep
def update_target(self):
self.ddpg_graph.update_target(self.sess)
def sample(self):
obs, actions, rewards, new_obs, dones = [], [], [], [], []
for _ in range(
self.config["sample_batch_size"] + self.config["n_step"] - 1):
ob, act, rew, ob1, done = self._step(self.global_timestep)
obs.append(ob)
actions.append(act)
rewards.append(rew)
new_obs.append(ob1)
dones.append(done)
# N-step Q adjustments
if self.config["n_step"] > 1:
# Adjust for steps lost from truncation
self.local_timestep -= (self.config["n_step"] - 1)
adjust_nstep(self.config["n_step"], self.config["gamma"], obs,
actions, rewards, new_obs, dones)
batch = SampleBatch({
"obs": [pack(np.array(o)) for o in obs],
"actions": actions,
"rewards": rewards,
"new_obs": [pack(np.array(o)) for o in new_obs],
"dones": dones,
"weights": np.ones_like(rewards)
})
assert (batch.count == self.config["sample_batch_size"])
# Prioritize on the worker side
if self.config["worker_side_prioritization"]:
td_errors = self.ddpg_graph.compute_td_error(
self.sess, obs, batch["actions"], batch["rewards"], new_obs,
batch["dones"], batch["weights"])
new_priorities = (
np.abs(td_errors) + self.config["prioritized_replay_eps"])
batch.data["weights"] = new_priorities
return batch
def compute_gradients(self, samples):
td_err, grads = self.ddpg_graph.compute_gradients(
self.sess, samples["obs"], samples["actions"], samples["rewards"],
samples["new_obs"], samples["dones"], samples["weights"])
return grads, {"td_error": td_err}
def apply_gradients(self, grads):
self.ddpg_graph.apply_gradients(self.sess, grads)
def compute_apply(self, samples):
td_error = self.ddpg_graph.compute_apply(
self.sess, samples["obs"], samples["actions"], samples["rewards"],
samples["new_obs"], samples["dones"], samples["weights"])
return {"td_error": td_error}
def get_weights(self):
return self.variables.get_weights()
def set_weights(self, weights):
self.variables.set_weights(weights)
def _step(self, global_timestep):
"""Takes a single step, and returns the result of the step."""
action = self.ddpg_graph.act(
self.sess,
np.array(self.obs)[None],
self.exploration.value(global_timestep))[0]
new_obs, rew, done, _ = self.env.step(action)
ret = (self.obs, action, rew, new_obs, float(done))
self.obs = new_obs
self.episode_rewards[-1] += rew
self.episode_lengths[-1] += 1
if done:
self.obs = self.env.reset()
self.episode_rewards.append(0.0)
self.episode_lengths.append(0.0)
# reset OU (Ornstein-Uhlenbeck) noise at the start of each episode
self.ddpg_graph.reset_noise(self.sess)
self.local_timestep += 1
return ret
def stats(self):
n = self.config["smoothing_num_episodes"] + 1
mean_100ep_reward = round(np.mean(self.episode_rewards[-n:-1]), 5)
mean_100ep_length = round(np.mean(self.episode_lengths[-n:-1]), 5)
exploration = self.exploration.value(self.global_timestep)
return {
"mean_100ep_reward": mean_100ep_reward,
"mean_100ep_length": mean_100ep_length,
"num_episodes": len(self.episode_rewards),
"exploration": exploration,
"local_timestep": self.local_timestep,
}
def save(self):
return [
self.exploration, self.episode_rewards, self.episode_lengths,
self.saved_mean_reward, self.obs, self.global_timestep,
self.local_timestep
]
def restore(self, data):
self.exploration = data[0]
self.episode_rewards = data[1]
self.episode_lengths = data[2]
self.saved_mean_reward = data[3]
self.obs = data[4]
self.global_timestep = data[5]
self.local_timestep = data[6]
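With per_worker_exploration enabled, each worker receives a constant noise scale of noise_scale * 0.4 ** (1 + worker_index / (num_workers - 1) * 7), so exploration falls off geometrically across workers. A small worked example (illustrative only) of the resulting values:

num_workers = 8
noise_scale = 0.1
for worker_index in range(num_workers):
    eps = noise_scale * 0.4 ** (1 + worker_index / float(num_workers - 1) * 7)
    print(worker_index, round(eps, 6))
# Worker 0 gets 0.4 * noise_scale = 0.04; worker 7 gets 0.4 ** 8 * noise_scale ~= 6.6e-5,
# which is why _train_stats can treat the last ~20% of workers as near-greedy evaluators.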

@@ -0,0 +1,391 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
from ray.rllib.models import ModelCatalog
def _build_p_network(registry, inputs, dim_actions, config):
"""
Maps an observation (i.e., state) to an action whose entries each take a
value in (0, 1) due to the final sigmoid layer
"""
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
hiddens = config["actor_hiddens"]
action_out = frontend.last_layer
for hidden in hiddens:
action_out = layers.fully_connected(
action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
# Use sigmoid layer to bound values within (0, 1)
# shape of action_scores is [batch_size, dim_actions]
action_scores = layers.fully_connected(
action_out, num_outputs=dim_actions, activation_fn=tf.nn.sigmoid)
return action_scores
# Acts as a stochastic policy for inference but as a deterministic policy for
# training; the batch-size issue is therefore ignored when constructing the
# stochastic action
def _build_action_network(p_values, low_action, high_action, stochastic, eps,
theta, sigma):
# shape is [None, dim_action]
deterministic_actions = (high_action - low_action) * p_values + low_action
exploration_sample = tf.get_variable(
name="ornstein_uhlenbeck",
dtype=tf.float32,
initializer=low_action.size * [.0],
trainable=False)
normal_sample = tf.random_normal(
shape=[low_action.size], mean=0.0, stddev=1.0)
exploration_value = tf.assign_add(
exploration_sample,
theta * (.0 - exploration_sample) + sigma * normal_sample)
stochastic_actions = deterministic_actions + eps * (
high_action - low_action) * exploration_value
return tf.cond(stochastic, lambda: stochastic_actions,
lambda: deterministic_actions)
def _build_q_network(registry, inputs, action_inputs, config):
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
hiddens = config["critic_hiddens"]
q_out = tf.concat([frontend.last_layer, action_inputs], axis=1)
for hidden in hiddens:
q_out = layers.fully_connected(
q_out, num_outputs=hidden, activation_fn=tf.nn.relu)
q_scores = layers.fully_connected(q_out, num_outputs=1, activation_fn=None)
return q_scores
def _huber_loss(x, delta=1.0):
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
return tf.where(
tf.abs(x) < delta,
tf.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta))
def _minimize_and_clip(optimizer, objective, var_list, clip_val=10):
"""Minimized `objective` using `optimizer` w.r.t. variables in
`var_list` while ensure the norm of the gradients for each
variable is clipped to `clip_val`
"""
gradients = optimizer.compute_gradients(objective, var_list=var_list)
for i, (grad, var) in enumerate(gradients):
if grad is not None:
gradients[i] = (tf.clip_by_norm(grad, clip_val), var)
return gradients
def _scope_vars(scope, trainable_only=False):
"""
Get variables inside a scope
The scope can be specified as a string
Parameters
----------
scope: str or VariableScope
scope in which the variables reside.
trainable_only: bool
whether or not to return only the variables that were marked as
trainable.
Returns
-------
vars: [tf.Variable]
list of variables in `scope`.
"""
return tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES
if trainable_only else tf.GraphKeys.VARIABLES,
scope=scope if isinstance(scope, str) else scope.name)
class ModelAndLoss(object):
"""Holds the model and loss function.
Both graphs are necessary in order for the multi-gpu SGD implementation
to create towers on each device.
"""
def __init__(self, registry, dim_actions, low_action, high_action, config,
obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights):
# p network evaluation
with tf.variable_scope("p_func", reuse=True) as scope:
self.p_t = _build_p_network(registry, obs_t, dim_actions, config)
# target p network evaluation
with tf.variable_scope("target_p_func") as scope:
self.p_tp1 = _build_p_network(registry, obs_tp1, dim_actions,
config)
self.target_p_func_vars = _scope_vars(scope.name)
# Action outputs
with tf.variable_scope("a_func", reuse=True):
deterministic_flag = tf.constant(value=False, dtype=tf.bool)
zero_eps = tf.constant(value=.0, dtype=tf.float32)
output_actions = _build_action_network(
self.p_t, low_action, high_action, deterministic_flag,
zero_eps, config["exploration_theta"],
config["exploration_sigma"])
output_actions_estimated = _build_action_network(
self.p_tp1, low_action, high_action, deterministic_flag,
zero_eps, config["exploration_theta"],
config["exploration_sigma"])
# q network evaluation
with tf.variable_scope("q_func") as scope:
self.q_t = _build_q_network(registry, obs_t, act_t, config)
self.q_func_vars = _scope_vars(scope.name)
with tf.variable_scope("q_func", reuse=True):
self.q_tp0 = _build_q_network(registry, obs_t, output_actions,
config)
# target q network evaluation
with tf.variable_scope("target_q_func") as scope:
self.q_tp1 = _build_q_network(registry, obs_tp1,
output_actions_estimated, config)
self.target_q_func_vars = _scope_vars(scope.name)
q_t_selected = tf.squeeze(self.q_t, axis=len(self.q_t.shape) - 1)
q_tp1_best = tf.squeeze(
input=self.q_tp1, axis=len(self.q_tp1.shape) - 1)
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = (
rew_t + config["gamma"]**config["n_step"] * q_tp1_best_masked)
# compute the error (potentially clipped)
self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
if config.get("use_huber"):
errors = _huber_loss(self.td_error, config.get("huber_threshold"))
else:
errors = 0.5 * tf.square(self.td_error)
weighted_error = tf.reduce_mean(importance_weights * errors)
self.loss = weighted_error
# for policy gradient
self.actor_loss = -1.0 * tf.reduce_mean(self.q_tp0)
class DDPGGraph(object):
def __init__(self, registry, env, config, logdir):
self.env = env
dim_actions = env.action_space.shape[0]
low_action = env.action_space.low
high_action = env.action_space.high
actor_optimizer = tf.train.AdamOptimizer(
learning_rate=config["actor_lr"])
critic_optimizer = tf.train.AdamOptimizer(
learning_rate=config["critic_lr"])
# Action inputs
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
self.eps = tf.placeholder(tf.float32, (), name="eps")
self.cur_observations = tf.placeholder(
tf.float32, shape=(None, ) + env.observation_space.shape)
# Actor: P (policy) network
p_scope_name = "p_func"
with tf.variable_scope(p_scope_name) as scope:
p_values = _build_p_network(registry, self.cur_observations,
dim_actions, config)
p_func_vars = _scope_vars(scope.name)
# Action outputs
a_scope_name = "a_func"
with tf.variable_scope(a_scope_name):
self.output_actions = _build_action_network(
p_values, low_action, high_action, self.stochastic, self.eps,
config["exploration_theta"], config["exploration_sigma"])
with tf.variable_scope(a_scope_name, reuse=True):
exploration_sample = tf.get_variable(name="ornstein_uhlenbeck")
self.reset_noise_op = tf.assign(exploration_sample,
dim_actions * [.0])
# Replay inputs
self.obs_t = tf.placeholder(
tf.float32,
shape=(None, ) + env.observation_space.shape,
name="observation")
self.act_t = tf.placeholder(
tf.float32, shape=(None, ) + env.action_space.shape, name="action")
self.rew_t = tf.placeholder(tf.float32, [None], name="reward")
self.obs_tp1 = tf.placeholder(
tf.float32, shape=(None, ) + env.observation_space.shape)
self.done_mask = tf.placeholder(tf.float32, [None], name="done")
self.importance_weights = tf.placeholder(
tf.float32, [None], name="weight")
def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
return ModelAndLoss(registry, dim_actions, low_action, high_action,
config, obs_t, act_t, rew_t, obs_tp1,
done_mask, importance_weights)
self.loss_inputs = [
("obs", self.obs_t),
("actions", self.act_t),
("rewards", self.rew_t),
("new_obs", self.obs_tp1),
("dones", self.done_mask),
("weights", self.importance_weights),
]
loss_obj = build_loss(self.obs_t, self.act_t, self.rew_t, self.obs_tp1,
self.done_mask, self.importance_weights)
self.build_loss = build_loss
actor_loss = loss_obj.actor_loss
weighted_error = loss_obj.loss
q_func_vars = loss_obj.q_func_vars
target_p_func_vars = loss_obj.target_p_func_vars
target_q_func_vars = loss_obj.target_q_func_vars
self.p_t = loss_obj.p_t
self.q_t = loss_obj.q_t
self.q_tp0 = loss_obj.q_tp0
self.q_tp1 = loss_obj.q_tp1
self.td_error = loss_obj.td_error
if config["l2_reg"] is not None:
for var in p_func_vars:
if "bias" not in var.name:
actor_loss += config["l2_reg"] * 0.5 * tf.nn.l2_loss(var)
for var in q_func_vars:
if "bias" not in var.name:
weighted_error += config["l2_reg"] * 0.5 * tf.nn.l2_loss(
var)
# compute optimization op (potentially with gradient clipping)
if config["grad_norm_clipping"] is not None:
self.actor_grads_and_vars = _minimize_and_clip(
actor_optimizer,
actor_loss,
var_list=p_func_vars,
clip_val=config["grad_norm_clipping"])
self.critic_grads_and_vars = _minimize_and_clip(
critic_optimizer,
weighted_error,
var_list=q_func_vars,
clip_val=config["grad_norm_clipping"])
else:
self.actor_grads_and_vars = actor_optimizer.compute_gradients(
actor_loss, var_list=p_func_vars)
self.critic_grads_and_vars = critic_optimizer.compute_gradients(
weighted_error, var_list=q_func_vars)
self.actor_grads_and_vars = [(g, v)
for (g, v) in self.actor_grads_and_vars
if g is not None]
self.critic_grads_and_vars = [(g, v)
for (g, v) in self.critic_grads_and_vars
if g is not None]
self.grads_and_vars = (
self.actor_grads_and_vars + self.critic_grads_and_vars)
self.grads = [g for (g, v) in self.grads_and_vars]
self.actor_train_expr = actor_optimizer.apply_gradients(
self.actor_grads_and_vars)
self.critic_train_expr = critic_optimizer.apply_gradients(
self.critic_grads_and_vars)
# update_target_expr will be called periodically to copy the Q and policy
# networks to their target networks
self.tau_value = config.get("tau")
self.tau = tf.placeholder(tf.float32, (), name="tau")
update_target_expr = []
for var, var_target in zip(
sorted(q_func_vars, key=lambda v: v.name),
sorted(target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(
var_target.assign(self.tau * var +
(1.0 - self.tau) * var_target))
for var, var_target in zip(
sorted(p_func_vars, key=lambda v: v.name),
sorted(target_p_func_vars, key=lambda v: v.name)):
update_target_expr.append(
var_target.assign(self.tau * var +
(1.0 - self.tau) * var_target))
self.update_target_expr = tf.group(*update_target_expr)
# support both hard and soft sync
def update_target(self, sess, tau=None):
return sess.run(
self.update_target_expr,
feed_dict={self.tau: tau or self.tau_value})
def act(self, sess, obs, eps, stochastic=True):
return sess.run(
self.output_actions,
feed_dict={
self.cur_observations: obs,
self.stochastic: stochastic,
self.eps: eps
})
def compute_gradients(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err, grads = sess.run(
[self.td_error, self.grads],
feed_dict={
self.obs_t: obs_t,
self.act_t: act_t,
self.rew_t: rew_t,
self.obs_tp1: obs_tp1,
self.done_mask: done_mask,
self.importance_weights: importance_weights
})
return td_err, grads
def compute_td_error(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err = sess.run(
self.td_error,
feed_dict={
self.obs_t: [np.array(ob) for ob in obs_t],
self.act_t: act_t,
self.rew_t: rew_t,
self.obs_tp1: [np.array(ob) for ob in obs_tp1],
self.done_mask: done_mask,
self.importance_weights: importance_weights
})
return td_err
def apply_gradients(self, sess, grads):
assert len(grads) == len(self.grads_and_vars)
feed_dict = {ph: g for (g, ph) in zip(grads, self.grads)}
sess.run(
[self.critic_train_expr, self.actor_train_expr],
feed_dict=feed_dict)
def compute_apply(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err, _, _ = sess.run(
[self.td_error, self.critic_train_expr, self.actor_train_expr],
feed_dict={
self.obs_t: obs_t,
self.act_t: act_t,
self.rew_t: rew_t,
self.obs_tp1: obs_tp1,
self.done_mask: done_mask,
self.importance_weights: importance_weights
})
return td_err
def reset_noise(self, sess):
sess.run(self.reset_noise_op)
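The exploration noise above is an Ornstein-Uhlenbeck process kept in the non-trainable "ornstein_uhlenbeck" variable: each call to act() advances it by theta * (0 - x) + sigma * N(0, 1), and reset_noise() zeroes it at episode boundaries. A NumPy sketch of the same update (illustrative only, outside TensorFlow; default theta/sigma from DEFAULT_CONFIG):

import numpy as np

def ou_step(state, theta=0.15, sigma=0.2, rng=np.random):
    # Mirrors exploration_value above: x += theta * (0 - x) + sigma * N(0, 1)
    return state + theta * (0.0 - state) + sigma * rng.standard_normal(state.shape)

state = np.zeros(1)  # reset_noise() sets the variable back to zeros
for t in range(5):
    state = ou_step(state)
    # A stochastic action is then deterministic_action + eps * (high - low) * state
    print(t, state)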

@@ -22,6 +22,7 @@ ray.init()
CONFIGS = {
    "ES": {"episodes_per_batch": 10, "timesteps_per_batch": 100},
    "DQN": {},
+    "DDPG2": {"noise_scale": 0.0},
    "PPO": {"num_sgd_iter": 5, "timesteps_per_batch": 1000},
    "A3C": {"use_lstm": False},
}
@@ -29,8 +30,12 @@ CONFIGS = {
def test(use_object_store, alg_name):
    cls = get_agent_class(alg_name)
-    alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
-    alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
+    if alg_name == "DDPG2":
+        alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
+        alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
+    else:
+        alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
+        alg2 = cls(config=CONFIGS[name], env="CartPole-v0")

    for _ in range(3):
        res = alg1.train()
@@ -43,7 +48,10 @@ def test(use_object_store, alg_name):
    alg2.restore(alg1.save())

    for _ in range(10):
-        obs = np.random.uniform(size=4)
+        if alg_name == "DDPG2":
+            obs = np.random.uniform(size=3)
+        else:
+            obs = np.random.uniform(size=4)
        a1 = get_mean_action(alg1, obs)
        a2 = get_mean_action(alg2, obs)
        print("Checking computed actions", alg1, obs, a1, a2)
@@ -53,7 +61,7 @@ def test(use_object_store, alg_name):
if __name__ == "__main__":
    # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
    for use_object_store in [False, True]:
-        for name in ["ES", "DQN", "PPO", "A3C"]:
+        for name in ["ES", "DQN", "DDPG2", "PPO", "A3C"]:
            test(use_object_store, name)

    print("All checkpoint restore tests passed!")

@@ -114,6 +114,7 @@ class ModelSupportedSpaces(unittest.TestCase):
    def testAll(self):
        ray.init()
        stats = {}
+        check_support("DDPG2", {"timesteps_per_iteration": 1}, stats)
        check_support("DQN", {"timesteps_per_iteration": 1}, stats)
        check_support(
            "A3C", {"num_workers": 1, "optimizer": {"grads_per_step": 1}},

@@ -0,0 +1,18 @@
# This can be expected to reach 90 reward within ~1.5-2.5m timesteps / ~150-250 seconds on a K40 GPU
mountaincarcontinuous-apex-ddpg-2:
env: MountainCarContinuous-v0
run: APEX_DDPG2
trial_resources:
cpu: 1
gpu: 1
extra_cpu:
eval: 4 + spec.config.num_workers
stop:
episode_reward_mean: 90
config:
clip_rewards: False
num_workers: 16
noise_scale: 1.0
n_step: 3
target_network_update_freq: 50000
tau: 1.0

@@ -0,0 +1,22 @@
# can expect improvement to 90 reward in ~12-24k timesteps
mountaincarcontinuous-ddpg-2:
env: MountainCarContinuous-v0
run: DDPG2
trial_resources:
cpu: 6
stop:
episode_reward_mean: 90
config:
n_step: 3
actor_hiddens: [32, 64]
critic_hiddens: [64, 64]
noise_scale: 0.75
exploration_fraction: 0.4
tau: 0.01
l2_reg: 0.00001
buffer_size: 50000
random_starts: False
clip_rewards: False
learning_starts: 1000
#model:
# fcnet_hiddens: []

@@ -0,0 +1,18 @@
# This can be expected to reach -160 reward within 2.5 timesteps / ~250 seconds on a K40 GPU
pendulum-apex-ddpg-2:
env: Pendulum-v0
run: APEX_DDPG2
trial_resources:
cpu: 1
gpu: 1
extra_cpu:
eval: 4 + spec.config.num_workers
stop:
episode_reward_mean: -160
config:
use_huber: True
clip_rewards: False
num_workers: 16
n_step: 1
target_network_update_freq: 50000
tau: 1.0

@@ -0,0 +1,16 @@
# can expect improvement to -160 reward in ~30-40k timesteps
pendulum-ddpg-2:
env: Pendulum-v0
run: DDPG2
trial_resources:
cpu: 6
gpu: 1
stop:
episode_reward_mean: -160
config:
use_huber: True
random_starts: False
clip_rewards: False
exploration_fraction: 0.4
model:
fcnet_hiddens: []
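The tuned examples can also be launched programmatically; a sketch of the experiment above expressed as a Python dict, assuming the ray.tune.run_experiments API of this Ray version:

import ray
from ray.tune import run_experiments

ray.init()
run_experiments({
    "pendulum-ddpg-2": {
        "run": "DDPG2",
        "env": "Pendulum-v0",
        "trial_resources": {"cpu": 6, "gpu": 1},
        "stop": {"episode_reward_mean": -160},
        "config": {
            "use_huber": True,
            "random_starts": False,
            "clip_rewards": False,
            "exploration_fraction": 0.4,
            "model": {"fcnet_hiddens": []},
        },
    },
})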

@@ -0,0 +1,16 @@
pendulum-ddpg-2:
env: Pendulum-v0
run: DDPG2
trial_resources:
cpu: 2
stop:
episode_reward_mean: -160
time_total_s: 900
config:
use_huber: True
random_starts: False
clip_rewards: False
exploration_fraction: 0.4
model:
fcnet_hiddens: []
smoothing_num_episodes: 10