mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Contribute DDPG to RLlib (#1877)
* ongoing ddpg
* ongoing ddpg converged
* gpu machine changes
* tuned
* tuned ddpg specification
* ddpg
* supplement missed optimizer argument clip_rewards in default DQN configuration
* ddpg supports vision env (atari) now
* revised according to code review comments
* added regression test case
* removed irrelevant files
* validate ddpg on mountain_car_continuous
* restore unnecessary slight changes
* revised according to eric's comments
* added the requested tests
* revised accordingly
* revised accordingly and re-validated
* formatted by yapf
* fix lint errors
* formatted by yapf
* fix lint errors
* formatted by yapf
* fix lint error
parent aa07f1ce4e
commit c9a7744e52
16 changed files with 1013 additions and 6 deletions
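The diff below registers two new trainables, "DDPG2" and "APEX_DDPG2". As a quick orientation, here is a minimal usage sketch, not part of the commit, assuming a Ray build that includes these changes; the Pendulum-v0 environment and the noise_scale value mirror the regression tests added further down.

# Minimal sketch (not part of this commit): create and train the newly
# registered DDPG2 agent. Assumes a Ray build that includes this change.
import ray
from ray.rllib.agent import get_agent_class

ray.init()
cls = get_agent_class("DDPG2")                      # registered by this commit
agent = cls(config={"noise_scale": 0.1}, env="Pendulum-v0")
result = agent.train()                              # one training iteration
print(result.episode_reward_mean, result.timesteps_this_iter)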
@@ -9,7 +9,8 @@ from ray.tune.registry import register_trainable

 def _register_all():
     for key in ["PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG",
-                "__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
+                "DDPG2", "APEX_DDPG2", "__fake", "__sigmoid_fake_data",
+                "__parameter_tuning"]:
         from ray.rllib.agent import get_agent_class
         register_trainable(key, get_agent_class(key))
@@ -231,7 +231,13 @@ class _ParameterTuningAgent(_MockAgent):
 def get_agent_class(alg):
     """Returns the class of a known agent given its name."""

-    if alg == "PPO":
+    if alg == "DDPG2":
+        from ray.rllib import ddpg2
+        return ddpg2.DDPG2Agent
+    elif alg == "APEX_DDPG2":
+        from ray.rllib import ddpg2
+        return ddpg2.ApexDDPG2Agent
+    elif alg == "PPO":
         from ray.rllib import ppo
         return ppo.PPOAgent
     elif alg == "ES":
python/ray/rllib/ddpg2/README.md (new file, 1 line)
@@ -0,0 +1 @@
Code in this package follows the style of dqn.
python/ray/rllib/ddpg2/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.ddpg2.apex import ApexDDPG2Agent
from ray.rllib.ddpg2.ddpg import DDPG2Agent, DEFAULT_CONFIG

__all__ = ["DDPG2Agent", "ApexDDPG2Agent", "DEFAULT_CONFIG"]
python/ray/rllib/ddpg2/apex.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.ddpg2.ddpg import DDPG2Agent, DEFAULT_CONFIG as DDPG_CONFIG

APEX_DDPG_DEFAULT_CONFIG = dict(DDPG_CONFIG,
    **dict(
        optimizer_class="ApexOptimizer",
        optimizer_config=dict(
            DDPG_CONFIG["optimizer_config"],
            **dict(
                max_weight_sync_delay=400,
                num_replay_buffer_shards=4,
                debug=False,
            )),
        n_step=3,
        num_workers=32,
        buffer_size=2000000,
        learning_starts=50000,
        train_batch_size=512,
        sample_batch_size=50,
        max_weight_sync_delay=400,
        target_network_update_freq=500000,
        timesteps_per_iteration=25000,
        per_worker_exploration=True,
        worker_side_prioritization=True,
    ))


class ApexDDPG2Agent(DDPG2Agent):
    """DDPG variant that uses the Ape-X distributed policy optimizer.

    By default, this is configured for a large single node (32 cores). For
    running in a large cluster, increase the `num_workers` config var.
    """

    _agent_name = "APEX_DDPG"
    _default_config = APEX_DDPG_DEFAULT_CONFIG

    def update_target_if_needed(self):
        # Ape-X updates based on num steps trained, not sampled
        if self.optimizer.num_steps_trained - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.update_target()
            self.last_target_update_ts = self.optimizer.num_steps_trained
            self.num_target_updates += 1
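Per the ApexDDPG2Agent docstring, the defaults above target a single 32-core machine, and scaling out is mostly a matter of raising num_workers (and, if desired, the replay shard count). A hedged sketch of such an override follows; the worker and shard counts are illustrative values, not tuned settings from this commit.

# Illustrative cluster-sized override of APEX_DDPG_DEFAULT_CONFIG
# (example values only, not tuned settings from this commit).
from ray.rllib.ddpg2.apex import ApexDDPG2Agent, APEX_DDPG_DEFAULT_CONFIG

config = dict(APEX_DDPG_DEFAULT_CONFIG)
config["num_workers"] = 128                    # more parallel samplers
config["optimizer_config"] = dict(
    config["optimizer_config"], num_replay_buffer_shards=8)

agent = ApexDDPG2Agent(config=config, env="MountainCarContinuous-v0")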
python/ray/rllib/ddpg2/common/__init__.py (new empty file)

python/ray/rllib/ddpg2/ddpg.py (new file, 268 lines)
@@ -0,0 +1,268 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pickle
import os

import numpy as np
import tensorflow as tf

import ray
from ray.rllib import optimizers
from ray.rllib.ddpg2.ddpg_evaluator import DDPGEvaluator
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult

OPTIMIZER_SHARED_CONFIGS = [
    "buffer_size", "prioritized_replay", "prioritized_replay_alpha",
    "prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
    "train_batch_size", "learning_starts", "clip_rewards"
]

DEFAULT_CONFIG = dict(
    # === Model ===
    # Hidden layer sizes of the policy networks
    actor_hiddens=[64, 64],
    # Hidden layer sizes of the policy networks
    critic_hiddens=[64, 64],
    # N-step Q learning
    n_step=1,
    # Config options to pass to the model constructor
    model={},
    # Discount factor for the MDP
    gamma=0.99,
    # Arguments to pass to the env creator
    env_config={},

    # === Exploration ===
    # Max num timesteps for annealing schedules. Exploration is annealed from
    # 1.0 to exploration_fraction over this number of timesteps scaled by
    # exploration_fraction
    schedule_max_timesteps=100000,
    # Number of env steps to optimize for before returning
    timesteps_per_iteration=1000,
    # Fraction of entire training period over which the exploration rate is
    # annealed
    exploration_fraction=0.1,
    # Final value of random action probability
    exploration_final_eps=0.02,
    # OU-noise scale
    noise_scale=0.1,
    # theta
    exploration_theta=0.15,
    # sigma
    exploration_sigma=0.2,
    # Update the target network every `target_network_update_freq` steps.
    target_network_update_freq=0,
    # Update the target by \tau * policy + (1-\tau) * target_policy
    tau=0.002,
    # Whether to start with random actions instead of noops.
    random_starts=True,

    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    buffer_size=50000,
    # If True prioritized replay buffer will be used.
    prioritized_replay=True,
    # Alpha parameter for prioritized replay buffer.
    prioritized_replay_alpha=0.6,
    # Beta parameter for sampling from prioritized replay buffer.
    prioritized_replay_beta=0.4,
    # Epsilon to add to the TD errors when updating priorities.
    prioritized_replay_eps=1e-6,
    # Whether to clip rewards to [-1, 1] prior to adding to the replay buffer.
    clip_rewards=True,

    # === Optimization ===
    # Learning rate for adam optimizer
    actor_lr=1e-4,
    critic_lr=1e-3,
    # If True, use huber loss instead of squared loss for critic network
    # Conventionally, no need to clip gradients if using a huber loss
    use_huber=False,
    # Threshold of a huber loss
    huber_threshold=1.0,
    # Weights for L2 regularization
    l2_reg=1e-6,
    # If not None, clip gradients during optimization at this value
    grad_norm_clipping=None,
    # How many steps of the model to sample before learning starts.
    learning_starts=1500,
    # Update the replay buffer with this many samples at once. Note that this
    # setting applies per-worker if num_workers > 1.
    sample_batch_size=1,
    # Size of a batched sampled from replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    train_batch_size=256,
    # Smooth the current average reward over this many previous episodes.
    smoothing_num_episodes=100,

    # === Tensorflow ===
    # Arguments to pass to tensorflow
    tf_session_args={
        "device_count": {
            "CPU": 2
        },
        "log_device_placement": False,
        "allow_soft_placement": True,
        "gpu_options": {
            "allow_growth": True
        },
        "inter_op_parallelism_threads": 1,
        "intra_op_parallelism_threads": 1,
    },

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    num_workers=0,
    # Whether to allocate GPUs for workers (if > 0).
    num_gpus_per_worker=0,
    # Optimizer class to use.
    optimizer_class="LocalSyncReplayOptimizer",
    # Config to pass to the optimizer.
    optimizer_config=dict(),
    # Whether to use a distribution of epsilons across workers for exploration.
    per_worker_exploration=False,
    # Whether to compute priorities on workers.
    worker_side_prioritization=False)


class DDPG2Agent(Agent):
    _agent_name = "DDPG2"
    _allow_unknown_subkeys = [
        "model", "optimizer", "tf_session_args", "env_config"
    ]
    _default_config = DEFAULT_CONFIG

    def _init(self):
        self.local_evaluator = DDPGEvaluator(self.registry, self.env_creator,
                                             self.config, self.logdir, 0)
        remote_cls = ray.remote(
            num_cpus=1,
            num_gpus=self.config["num_gpus_per_worker"])(DDPGEvaluator)
        self.remote_evaluators = [
            remote_cls.remote(self.registry, self.env_creator, self.config,
                              self.logdir, i)
            for i in range(self.config["num_workers"])
        ]

        for k in OPTIMIZER_SHARED_CONFIGS:
            if k not in self.config["optimizer_config"]:
                self.config["optimizer_config"][k] = self.config[k]

        self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
            self.config["optimizer_config"], self.local_evaluator,
            self.remote_evaluators)

        self.saver = tf.train.Saver(max_to_keep=None)
        self.last_target_update_ts = 0
        self.num_target_updates = 0

    @property
    def global_timestep(self):
        return self.optimizer.num_steps_sampled

    def update_target_if_needed(self):
        if self.global_timestep - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.update_target()
            self.last_target_update_ts = self.global_timestep
            self.num_target_updates += 1

    def _train(self):
        start_timestep = self.global_timestep

        while (self.global_timestep - start_timestep <
               self.config["timesteps_per_iteration"]):

            self.optimizer.step()
            self.update_target_if_needed()

        self.local_evaluator.set_global_timestep(self.global_timestep)
        for e in self.remote_evaluators:
            e.set_global_timestep.remote(self.global_timestep)

        return self._train_stats(start_timestep)

    def _train_stats(self, start_timestep):
        if self.remote_evaluators:
            stats = ray.get([e.stats.remote() for e in self.remote_evaluators])
        else:
            stats = self.local_evaluator.stats()
            if not isinstance(stats, list):
                stats = [stats]

        mean_100ep_reward = 0.0
        mean_100ep_length = 0.0
        num_episodes = 0
        explorations = []

        if self.config["per_worker_exploration"]:
            # Return stats from workers with the lowest 20% of exploration
            test_stats = stats[-int(max(1, len(stats) * 0.2)):]
        else:
            test_stats = stats

        for s in test_stats:
            mean_100ep_reward += s["mean_100ep_reward"] / len(test_stats)
            mean_100ep_length += s["mean_100ep_length"] / len(test_stats)

        for s in stats:
            num_episodes += s["num_episodes"]
            explorations.append(s["exploration"])

        opt_stats = self.optimizer.stats()

        result = TrainingResult(
            episode_reward_mean=mean_100ep_reward,
            episode_len_mean=mean_100ep_length,
            episodes_total=num_episodes,
            timesteps_this_iter=self.global_timestep - start_timestep,
            info=dict({
                "min_exploration": min(explorations),
                "max_exploration": max(explorations),
                "num_target_updates": self.num_target_updates,
            }, **opt_stats))

        return result

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote(ev._ray_actor_id.id())

    def _save(self, checkpoint_dir):
        checkpoint_path = self.saver.save(
            self.local_evaluator.sess,
            os.path.join(checkpoint_dir, "checkpoint"),
            global_step=self.iteration)
        extra_data = [
            self.local_evaluator.save(),
            ray.get([e.save.remote() for e in self.remote_evaluators]),
            self.optimizer.save(), self.num_target_updates,
            self.last_target_update_ts
        ]
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.saver.restore(self.local_evaluator.sess, checkpoint_path)
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        self.local_evaluator.restore(extra_data[0])
        ray.get([
            e.restore.remote(d)
            for (d, e) in zip(extra_data[1], self.remote_evaluators)
        ])
        self.optimizer.restore(extra_data[2])
        self.num_target_updates = extra_data[3]
        self.last_target_update_ts = extra_data[4]

    def compute_action(self, observation):
        return self.local_evaluator.ddpg_graph.act(self.local_evaluator.sess,
                                                   np.array(observation)[None],
                                                   0.0)[0]
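The _save/_restore pair above persists the TensorFlow session via tf.train.Saver plus a pickled list of evaluator and optimizer state. A short sketch of how that surfaces through the public Agent API; it mirrors the checkpoint-restore test modified later in this diff and is illustrative, not part of the commit.

# Sketch: checkpoint and restore a DDPG2 agent through the Agent API,
# which routes into the _save()/_restore() methods defined above.
import numpy as np
import ray
from ray.rllib.ddpg2 import DDPG2Agent

ray.init()
agent = DDPG2Agent(config={"noise_scale": 0.0}, env="Pendulum-v0")
agent.train()
checkpoint_path = agent.save()        # tf checkpoint + ".extra_data" pickle

agent2 = DDPG2Agent(config={"noise_scale": 0.0}, env="Pendulum-v0")
agent2.restore(checkpoint_path)       # same weights and target-update counters
action = agent2.compute_action(np.random.uniform(size=3))  # Pendulum obs is 3-dim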
python/ray/rllib/ddpg2/ddpg_evaluator.py (new file, 186 lines)
@@ -0,0 +1,186 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from gym.spaces import Box
import numpy as np
import tensorflow as tf

import ray
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.ddpg2 import models
from ray.rllib.dqn.common.schedules import ConstantSchedule, LinearSchedule
from ray.rllib.optimizers import SampleBatch, PolicyEvaluator
from ray.rllib.utils.compression import pack
from ray.rllib.dqn.dqn_evaluator import adjust_nstep
from ray.rllib.dqn.common.wrappers import wrap_dqn


class DDPGEvaluator(PolicyEvaluator):
    """The base DDPG Evaluator."""

    def __init__(self, registry, env_creator, config, logdir, worker_index):
        env = env_creator(config["env_config"])
        env = wrap_dqn(registry, env, config["model"], config["random_starts"])
        self.env = env
        self.config = config

        # when env.action_space is of Box type, e.g., Pendulum-v0
        # action_space.low is [-2.0], high is [2.0]
        # take action by calling, e.g., env.step([3.5])
        if not isinstance(env.action_space, Box):
            raise UnsupportedSpaceException(
                "Action space {} is not supported for DDPG.".format(
                    env.action_space))

        tf_config = tf.ConfigProto(**config["tf_session_args"])
        self.sess = tf.Session(config=tf_config)
        self.ddpg_graph = models.DDPGGraph(registry, env, config, logdir)

        # Use either a different `eps` per worker, or a linear schedule.
        if config["per_worker_exploration"]:
            assert config["num_workers"] > 1, "This requires multiple workers"
            self.exploration = ConstantSchedule(
                config["noise_scale"] * 0.4 **
                (1 + worker_index / float(config["num_workers"] - 1) * 7))
        else:
            self.exploration = LinearSchedule(
                schedule_timesteps=int(config["exploration_fraction"] *
                                       config["schedule_max_timesteps"]),
                initial_p=config["noise_scale"] * 1.0,
                final_p=config["noise_scale"] *
                config["exploration_final_eps"])

        # Initialize the parameters and copy them to the target network.
        self.sess.run(tf.global_variables_initializer())
        # hard instead of soft
        self.ddpg_graph.update_target(self.sess, 1.0)
        self.global_timestep = 0
        self.local_timestep = 0

        # Note that this encompasses both the policy and Q-value networks and
        # their corresponding target networks
        self.variables = ray.experimental.TensorFlowVariables(
            tf.group(self.ddpg_graph.q_tp0, self.ddpg_graph.q_tp1), self.sess)

        self.episode_rewards = [0.0]
        self.episode_lengths = [0.0]
        self.saved_mean_reward = None

        self.obs = self.env.reset()

    def set_global_timestep(self, global_timestep):
        self.global_timestep = global_timestep

    def update_target(self):
        self.ddpg_graph.update_target(self.sess)

    def sample(self):
        obs, actions, rewards, new_obs, dones = [], [], [], [], []
        for _ in range(
                self.config["sample_batch_size"] + self.config["n_step"] - 1):
            ob, act, rew, ob1, done = self._step(self.global_timestep)
            obs.append(ob)
            actions.append(act)
            rewards.append(rew)
            new_obs.append(ob1)
            dones.append(done)

        # N-step Q adjustments
        if self.config["n_step"] > 1:
            # Adjust for steps lost from truncation
            self.local_timestep -= (self.config["n_step"] - 1)
            adjust_nstep(self.config["n_step"], self.config["gamma"], obs,
                         actions, rewards, new_obs, dones)

        batch = SampleBatch({
            "obs": [pack(np.array(o)) for o in obs],
            "actions": actions,
            "rewards": rewards,
            "new_obs": [pack(np.array(o)) for o in new_obs],
            "dones": dones,
            "weights": np.ones_like(rewards)
        })
        assert (batch.count == self.config["sample_batch_size"])

        # Prioritize on the worker side
        if self.config["worker_side_prioritization"]:
            td_errors = self.ddpg_graph.compute_td_error(
                self.sess, obs, batch["actions"], batch["rewards"], new_obs,
                batch["dones"], batch["weights"])
            new_priorities = (
                np.abs(td_errors) + self.config["prioritized_replay_eps"])
            batch.data["weights"] = new_priorities

        return batch

    def compute_gradients(self, samples):
        td_err, grads = self.ddpg_graph.compute_gradients(
            self.sess, samples["obs"], samples["actions"], samples["rewards"],
            samples["new_obs"], samples["dones"], samples["weights"])
        return grads, {"td_error": td_err}

    def apply_gradients(self, grads):
        self.ddpg_graph.apply_gradients(self.sess, grads)

    def compute_apply(self, samples):
        td_error = self.ddpg_graph.compute_apply(
            self.sess, samples["obs"], samples["actions"], samples["rewards"],
            samples["new_obs"], samples["dones"], samples["weights"])
        return {"td_error": td_error}

    def get_weights(self):
        return self.variables.get_weights()

    def set_weights(self, weights):
        self.variables.set_weights(weights)

    def _step(self, global_timestep):
        """Takes a single step, and returns the result of the step."""
        action = self.ddpg_graph.act(
            self.sess,
            np.array(self.obs)[None],
            self.exploration.value(global_timestep))[0]
        new_obs, rew, done, _ = self.env.step(action)
        ret = (self.obs, action, rew, new_obs, float(done))
        self.obs = new_obs
        self.episode_rewards[-1] += rew
        self.episode_lengths[-1] += 1
        if done:
            self.obs = self.env.reset()
            self.episode_rewards.append(0.0)
            self.episode_lengths.append(0.0)
            # reset UO noise for each episode
            self.ddpg_graph.reset_noise(self.sess)

        self.local_timestep += 1
        return ret

    def stats(self):
        n = self.config["smoothing_num_episodes"] + 1
        mean_100ep_reward = round(np.mean(self.episode_rewards[-n:-1]), 5)
        mean_100ep_length = round(np.mean(self.episode_lengths[-n:-1]), 5)
        exploration = self.exploration.value(self.global_timestep)
        return {
            "mean_100ep_reward": mean_100ep_reward,
            "mean_100ep_length": mean_100ep_length,
            "num_episodes": len(self.episode_rewards),
            "exploration": exploration,
            "local_timestep": self.local_timestep,
        }

    def save(self):
        return [
            self.exploration, self.episode_rewards, self.episode_lengths,
            self.saved_mean_reward, self.obs, self.global_timestep,
            self.local_timestep
        ]

    def restore(self, data):
        self.exploration = data[0]
        self.episode_rewards = data[1]
        self.episode_lengths = data[2]
        self.saved_mean_reward = data[3]
        self.obs = data[4]
        self.global_timestep = data[5]
        self.local_timestep = data[6]
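When per_worker_exploration is enabled, worker i of N gets a constant noise scale of noise_scale * 0.4 ** (1 + i / (N - 1) * 7), the same geometric spread of exploration levels used for epsilon in Ape-X DQN. A small plain-Python sketch (not part of the diff) that just prints those per-worker values for the defaults used in the Ape-X YAMLs below:

# Sketch: constant per-worker noise scales produced by the schedule above
# (noise_scale=1.0, num_workers=16 as in the MountainCarContinuous Ape-X YAML).
noise_scale, num_workers = 1.0, 16
scales = [
    noise_scale * 0.4 ** (1 + i / float(num_workers - 1) * 7)
    for i in range(num_workers)
]
print(["%.5f" % s for s in scales])   # from 0.4 down to 0.4 ** 8, about 0.00066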
python/ray/rllib/ddpg2/models.py (new file, 391 lines)
@@ -0,0 +1,391 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import tensorflow as tf
import tensorflow.contrib.layers as layers

from ray.rllib.models import ModelCatalog


def _build_p_network(registry, inputs, dim_actions, config):
    """
    map an observation (i.e., state) to an action where
    each entry takes value from (0, 1) due to the sigmoid function
    """
    frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])

    hiddens = config["actor_hiddens"]
    action_out = frontend.last_layer
    for hidden in hiddens:
        action_out = layers.fully_connected(
            action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
    # Use sigmoid layer to bound values within (0, 1)
    # shape of action_scores is [batch_size, dim_actions]
    action_scores = layers.fully_connected(
        action_out, num_outputs=dim_actions, activation_fn=tf.nn.sigmoid)

    return action_scores


# As a stochastic policy for inference, but a deterministic policy for training
# thus ignore batch_size issue when constructing a stochastic action
def _build_action_network(p_values, low_action, high_action, stochastic, eps,
                          theta, sigma):
    # shape is [None, dim_action]
    deterministic_actions = (high_action - low_action) * p_values + low_action

    exploration_sample = tf.get_variable(
        name="ornstein_uhlenbeck",
        dtype=tf.float32,
        initializer=low_action.size * [.0],
        trainable=False)
    normal_sample = tf.random_normal(
        shape=[low_action.size], mean=0.0, stddev=1.0)
    exploration_value = tf.assign_add(
        exploration_sample,
        theta * (.0 - exploration_sample) + sigma * normal_sample)
    stochastic_actions = deterministic_actions + eps * (
        high_action - low_action) * exploration_value

    return tf.cond(stochastic, lambda: stochastic_actions,
                   lambda: deterministic_actions)
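_build_action_network keeps a persistent "ornstein_uhlenbeck" variable and nudges it by theta * (0 - x) + sigma * N(0, 1) each time an action is sampled, then scales it by eps and the action range. A NumPy sketch of the same update, illustrative only and not part of the commit; theta and sigma default to 0.15 and 0.2 per DEFAULT_CONFIG above:

# NumPy sketch of the Ornstein-Uhlenbeck exploration noise used above.
import numpy as np

def ou_step(state, theta=0.15, sigma=0.2):
    # Mean-reverting update toward 0, mirroring the tf.assign_add above.
    return state + theta * (0.0 - state) + sigma * np.random.normal(size=state.shape)

noise = np.zeros(1)                   # one entry per action dimension
for _ in range(5):
    noise = ou_step(noise)
# stochastic action = deterministic action + eps * (high - low) * noise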
def _build_q_network(registry, inputs, action_inputs, config):
    frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])

    hiddens = config["critic_hiddens"]

    q_out = tf.concat([frontend.last_layer, action_inputs], axis=1)
    for hidden in hiddens:
        q_out = layers.fully_connected(
            q_out, num_outputs=hidden, activation_fn=tf.nn.relu)
    q_scores = layers.fully_connected(q_out, num_outputs=1, activation_fn=None)

    return q_scores


def _huber_loss(x, delta=1.0):
    """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
    return tf.where(
        tf.abs(x) < delta,
        tf.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta))


def _minimize_and_clip(optimizer, objective, var_list, clip_val=10):
    """Minimized `objective` using `optimizer` w.r.t. variables in
    `var_list` while ensure the norm of the gradients for each
    variable is clipped to `clip_val`
    """
    gradients = optimizer.compute_gradients(objective, var_list=var_list)
    for i, (grad, var) in enumerate(gradients):
        if grad is not None:
            gradients[i] = (tf.clip_by_norm(grad, clip_val), var)
    return gradients


def _scope_vars(scope, trainable_only=False):
    """
    Get variables inside a scope
    The scope can be specified as a string

    Parameters
    ----------
    scope: str or VariableScope
        scope in which the variables reside.
    trainable_only: bool
        whether or not to return only the variables that were marked as
        trainable.

    Returns
    -------
    vars: [tf.Variable]
        list of variables in `scope`.
    """
    return tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES
        if trainable_only else tf.GraphKeys.VARIABLES,
        scope=scope if isinstance(scope, str) else scope.name)


class ModelAndLoss(object):
    """Holds the model and loss function.

    Both graphs are necessary in order for the multi-gpu SGD implementation
    to create towers on each device.
    """

    def __init__(self, registry, dim_actions, low_action, high_action, config,
                 obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights):
        # p network evaluation
        with tf.variable_scope("p_func", reuse=True) as scope:
            self.p_t = _build_p_network(registry, obs_t, dim_actions, config)

        # target p network evaluation
        with tf.variable_scope("target_p_func") as scope:
            self.p_tp1 = _build_p_network(registry, obs_tp1, dim_actions,
                                          config)
            self.target_p_func_vars = _scope_vars(scope.name)

        # Action outputs
        with tf.variable_scope("a_func", reuse=True):
            deterministic_flag = tf.constant(value=False, dtype=tf.bool)
            zero_eps = tf.constant(value=.0, dtype=tf.float32)
            output_actions = _build_action_network(
                self.p_t, low_action, high_action, deterministic_flag,
                zero_eps, config["exploration_theta"],
                config["exploration_sigma"])

            output_actions_estimated = _build_action_network(
                self.p_tp1, low_action, high_action, deterministic_flag,
                zero_eps, config["exploration_theta"],
                config["exploration_sigma"])

        # q network evaluation
        with tf.variable_scope("q_func") as scope:
            self.q_t = _build_q_network(registry, obs_t, act_t, config)
            self.q_func_vars = _scope_vars(scope.name)
        with tf.variable_scope("q_func", reuse=True):
            self.q_tp0 = _build_q_network(registry, obs_t, output_actions,
                                          config)

        # target q network evalution
        with tf.variable_scope("target_q_func") as scope:
            self.q_tp1 = _build_q_network(registry, obs_tp1,
                                          output_actions_estimated, config)
            self.target_q_func_vars = _scope_vars(scope.name)

        q_t_selected = tf.squeeze(self.q_t, axis=len(self.q_t.shape) - 1)

        q_tp1_best = tf.squeeze(
            input=self.q_tp1, axis=len(self.q_tp1.shape) - 1)
        q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

        # compute RHS of bellman equation
        q_t_selected_target = (
            rew_t + config["gamma"]**config["n_step"] * q_tp1_best_masked)

        # compute the error (potentially clipped)
        self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
        if config.get("use_huber"):
            errors = _huber_loss(self.td_error, config.get("huber_threshold"))
        else:
            errors = 0.5 * tf.square(self.td_error)

        weighted_error = tf.reduce_mean(importance_weights * errors)

        self.loss = weighted_error

        # for policy gradient
        self.actor_loss = -1.0 * tf.reduce_mean(self.q_tp0)


class DDPGGraph(object):
    def __init__(self, registry, env, config, logdir):
        self.env = env
        dim_actions = env.action_space.shape[0]
        low_action = env.action_space.low
        high_action = env.action_space.high
        actor_optimizer = tf.train.AdamOptimizer(
            learning_rate=config["actor_lr"])
        critic_optimizer = tf.train.AdamOptimizer(
            learning_rate=config["critic_lr"])

        # Action inputs
        self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
        self.eps = tf.placeholder(tf.float32, (), name="eps")
        self.cur_observations = tf.placeholder(
            tf.float32, shape=(None, ) + env.observation_space.shape)

        # Actor: P (policy) network
        p_scope_name = "p_func"
        with tf.variable_scope(p_scope_name) as scope:
            p_values = _build_p_network(registry, self.cur_observations,
                                        dim_actions, config)
            p_func_vars = _scope_vars(scope.name)

        # Action outputs
        a_scope_name = "a_func"
        with tf.variable_scope(a_scope_name):
            self.output_actions = _build_action_network(
                p_values, low_action, high_action, self.stochastic, self.eps,
                config["exploration_theta"], config["exploration_sigma"])

        with tf.variable_scope(a_scope_name, reuse=True):
            exploration_sample = tf.get_variable(name="ornstein_uhlenbeck")
            self.reset_noise_op = tf.assign(exploration_sample,
                                            dim_actions * [.0])

        # Replay inputs
        self.obs_t = tf.placeholder(
            tf.float32,
            shape=(None, ) + env.observation_space.shape,
            name="observation")
        self.act_t = tf.placeholder(
            tf.float32, shape=(None, ) + env.action_space.shape, name="action")
        self.rew_t = tf.placeholder(tf.float32, [None], name="reward")
        self.obs_tp1 = tf.placeholder(
            tf.float32, shape=(None, ) + env.observation_space.shape)
        self.done_mask = tf.placeholder(tf.float32, [None], name="done")
        self.importance_weights = tf.placeholder(
            tf.float32, [None], name="weight")

        def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask,
                       importance_weights):
            return ModelAndLoss(registry, dim_actions, low_action, high_action,
                                config, obs_t, act_t, rew_t, obs_tp1,
                                done_mask, importance_weights)

        self.loss_inputs = [
            ("obs", self.obs_t),
            ("actions", self.act_t),
            ("rewards", self.rew_t),
            ("new_obs", self.obs_tp1),
            ("dones", self.done_mask),
            ("weights", self.importance_weights),
        ]

        loss_obj = build_loss(self.obs_t, self.act_t, self.rew_t, self.obs_tp1,
                              self.done_mask, self.importance_weights)

        self.build_loss = build_loss

        actor_loss = loss_obj.actor_loss
        weighted_error = loss_obj.loss
        q_func_vars = loss_obj.q_func_vars
        target_p_func_vars = loss_obj.target_p_func_vars
        target_q_func_vars = loss_obj.target_q_func_vars
        self.p_t = loss_obj.p_t
        self.q_t = loss_obj.q_t
        self.q_tp0 = loss_obj.q_tp0
        self.q_tp1 = loss_obj.q_tp1
        self.td_error = loss_obj.td_error

        if config["l2_reg"] is not None:
            for var in p_func_vars:
                if "bias" not in var.name:
                    actor_loss += config["l2_reg"] * 0.5 * tf.nn.l2_loss(var)
            for var in q_func_vars:
                if "bias" not in var.name:
                    weighted_error += config["l2_reg"] * 0.5 * tf.nn.l2_loss(
                        var)

        # compute optimization op (potentially with gradient clipping)
        if config["grad_norm_clipping"] is not None:
            self.actor_grads_and_vars = _minimize_and_clip(
                actor_optimizer,
                actor_loss,
                var_list=p_func_vars,
                clip_val=config["grad_norm_clipping"])
            self.critic_grads_and_vars = _minimize_and_clip(
                critic_optimizer,
                weighted_error,
                var_list=q_func_vars,
                clip_val=config["grad_norm_clipping"])
        else:
            self.actor_grads_and_vars = actor_optimizer.compute_gradients(
                actor_loss, var_list=p_func_vars)
            self.critic_grads_and_vars = critic_optimizer.compute_gradients(
                weighted_error, var_list=q_func_vars)
        self.actor_grads_and_vars = [(g, v)
                                     for (g, v) in self.actor_grads_and_vars
                                     if g is not None]
        self.critic_grads_and_vars = [(g, v)
                                      for (g, v) in self.critic_grads_and_vars
                                      if g is not None]
        self.grads_and_vars = (
            self.actor_grads_and_vars + self.critic_grads_and_vars)
        self.grads = [g for (g, v) in self.grads_and_vars]
        self.actor_train_expr = actor_optimizer.apply_gradients(
            self.actor_grads_and_vars)
        self.critic_train_expr = critic_optimizer.apply_gradients(
            self.critic_grads_and_vars)

        # update_target_fn will be called periodically to copy Q network to
        # target Q network
        self.tau_value = config.get("tau")
        self.tau = tf.placeholder(tf.float32, (), name="tau")
        update_target_expr = []
        for var, var_target in zip(
                sorted(q_func_vars, key=lambda v: v.name),
                sorted(target_q_func_vars, key=lambda v: v.name)):
            update_target_expr.append(
                var_target.assign(self.tau * var +
                                  (1.0 - self.tau) * var_target))
        for var, var_target in zip(
                sorted(p_func_vars, key=lambda v: v.name),
                sorted(target_p_func_vars, key=lambda v: v.name)):
            update_target_expr.append(
                var_target.assign(self.tau * var +
                                  (1.0 - self.tau) * var_target))
        self.update_target_expr = tf.group(*update_target_expr)

    # support both hard and soft sync
    def update_target(self, sess, tau=None):
        return sess.run(
            self.update_target_expr,
            feed_dict={self.tau: tau or self.tau_value})

    def act(self, sess, obs, eps, stochastic=True):
        return sess.run(
            self.output_actions,
            feed_dict={
                self.cur_observations: obs,
                self.stochastic: stochastic,
                self.eps: eps
            })

    def compute_gradients(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
                          importance_weights):
        td_err, grads = sess.run(
            [self.td_error, self.grads],
            feed_dict={
                self.obs_t: obs_t,
                self.act_t: act_t,
                self.rew_t: rew_t,
                self.obs_tp1: obs_tp1,
                self.done_mask: done_mask,
                self.importance_weights: importance_weights
            })
        return td_err, grads

    def compute_td_error(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
                         importance_weights):
        td_err = sess.run(
            self.td_error,
            feed_dict={
                self.obs_t: [np.array(ob) for ob in obs_t],
                self.act_t: act_t,
                self.rew_t: rew_t,
                self.obs_tp1: [np.array(ob) for ob in obs_tp1],
                self.done_mask: done_mask,
                self.importance_weights: importance_weights
            })
        return td_err

    def apply_gradients(self, sess, grads):
        assert len(grads) == len(self.grads_and_vars)
        feed_dict = {ph: g for (g, ph) in zip(grads, self.grads)}
        sess.run(
            [self.critic_train_expr, self.actor_train_expr],
            feed_dict=feed_dict)

    def compute_apply(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask,
                      importance_weights):
        td_err, _, _ = sess.run(
            [self.td_error, self.critic_train_expr, self.actor_train_expr],
            feed_dict={
                self.obs_t: obs_t,
                self.act_t: act_t,
                self.rew_t: rew_t,
                self.obs_tp1: obs_tp1,
                self.done_mask: done_mask,
                self.importance_weights: importance_weights
            })
        return td_err

    def reset_noise(self, sess):
        sess.run(self.reset_noise_op)
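update_target interpolates every target variable toward its online counterpart with the factor tau, so the default tau=0.002 gives a slow soft sync, while update_target(sess, 1.0), as called once from the evaluator constructor, is a hard copy. The same rule in plain NumPy terms, as an illustration only:

# Illustration of the tau-weighted target sync built in update_target_expr.
import numpy as np

def soft_update(target, online, tau):
    # tau=1.0 reproduces the hard copy done at initialization.
    return tau * online + (1.0 - tau) * target

target_w, online_w = np.zeros(3), np.ones(3)
print(soft_update(target_w, online_w, 0.002))   # [0.002 0.002 0.002]
print(soft_update(target_w, online_w, 1.0))     # [1. 1. 1.]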
@@ -22,6 +22,7 @@ ray.init()
 CONFIGS = {
     "ES": {"episodes_per_batch": 10, "timesteps_per_batch": 100},
     "DQN": {},
+    "DDPG2": {"noise_scale": 0.0},
     "PPO": {"num_sgd_iter": 5, "timesteps_per_batch": 1000},
     "A3C": {"use_lstm": False},
 }

@@ -29,6 +30,10 @@ CONFIGS = {

 def test(use_object_store, alg_name):
     cls = get_agent_class(alg_name)
-    alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
-    alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
+    if alg_name == "DDPG2":
+        alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
+        alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
+    else:
+        alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
+        alg2 = cls(config=CONFIGS[name], env="CartPole-v0")

@@ -43,6 +48,9 @@ def test(use_object_store, alg_name):
     alg2.restore(alg1.save())

     for _ in range(10):
-        obs = np.random.uniform(size=4)
+        if alg_name == "DDPG2":
+            obs = np.random.uniform(size=3)
+        else:
+            obs = np.random.uniform(size=4)
         a1 = get_mean_action(alg1, obs)
         a2 = get_mean_action(alg2, obs)

@@ -53,7 +61,7 @@ def test(use_object_store, alg_name):
 if __name__ == "__main__":
     # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
     for use_object_store in [False, True]:
-        for name in ["ES", "DQN", "PPO", "A3C"]:
+        for name in ["ES", "DQN", "DDPG2", "PPO", "A3C"]:
             test(use_object_store, name)

     print("All checkpoint restore tests passed!")
@@ -114,6 +114,7 @@ class ModelSupportedSpaces(unittest.TestCase):
     def testAll(self):
         ray.init()
         stats = {}
+        check_support("DDPG2", {"timesteps_per_iteration": 1}, stats)
         check_support("DQN", {"timesteps_per_iteration": 1}, stats)
         check_support(
             "A3C", {"num_workers": 1, "optimizer": {"grads_per_step": 1}},
@@ -0,0 +1,18 @@
# This can be expected to reach 90 reward within ~1.5-2.5m timesteps / ~150-250 seconds on a K40 GPU
mountaincarcontinuous-apex-ddpg-2:
    env: MountainCarContinuous-v0
    run: APEX_DDPG2
    trial_resources:
        cpu: 1
        gpu: 1
        extra_cpu:
            eval: 4 + spec.config.num_workers
    stop:
        episode_reward_mean: 90
    config:
        clip_rewards: False
        num_workers: 16
        noise_scale: 1.0
        n_step: 3
        target_network_update_freq: 50000
        tau: 1.0
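The tuned-example YAMLs added in this commit are Ray Tune experiment specs keyed by experiment name. A hedged sketch of one way to launch such a spec from Python; the file path below is illustrative, and passing the file to rllib's train.py is the usual alternative:

# Sketch: launch one of the tuned-example specs with Ray Tune (the yaml path
# is illustrative; point it at the file in your checkout).
import yaml
import ray
from ray.tune import run_experiments

ray.init()
with open("pendulum-ddpg2.yaml") as f:
    experiments = yaml.safe_load(f)    # {experiment_name: {env, run, config, ...}}
run_experiments(experiments)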
@@ -0,0 +1,22 @@
# can expect improvement to 90 reward in ~12-24k timesteps
mountaincarcontinuous-ddpg-2:
    env: MountainCarContinuous-v0
    run: DDPG2
    trial_resources:
        cpu: 6
    stop:
        episode_reward_mean: 90
    config:
        n_step: 3
        actor_hiddens: [32, 64]
        critic_hiddens: [64, 64]
        noise_scale: 0.75
        exploration_fraction: 0.4
        tau: 0.01
        l2_reg: 0.00001
        buffer_size: 50000
        random_starts: False
        clip_rewards: False
        learning_starts: 1000
        #model:
        #  fcnet_hiddens: []
python/ray/rllib/tuned_examples/pendulum-apex-ddpg2.yaml (new file, 18 lines)
@@ -0,0 +1,18 @@
# This can be expected to reach -160 reward within 2.5 timesteps / ~250 seconds on a K40 GPU
pendulum-apex-ddpg-2:
    env: Pendulum-v0
    run: APEX_DDPG2
    trial_resources:
        cpu: 1
        gpu: 1
        extra_cpu:
            eval: 4 + spec.config.num_workers
    stop:
        episode_reward_mean: -160
    config:
        use_huber: True
        clip_rewards: False
        num_workers: 16
        n_step: 1
        target_network_update_freq: 50000
        tau: 1.0
python/ray/rllib/tuned_examples/pendulum-ddpg2.yaml (new file, 16 lines)
@@ -0,0 +1,16 @@
# can expect improvement to -160 reward in ~30-40k timesteps
pendulum-ddpg-2:
    env: Pendulum-v0
    run: DDPG2
    trial_resources:
        cpu: 6
        gpu: 1
    stop:
        episode_reward_mean: -160
    config:
        use_huber: True
        random_starts: False
        clip_rewards: False
        exploration_fraction: 0.4
        model:
            fcnet_hiddens: []
@@ -0,0 +1,16 @@
pendulum-ddpg-2:
    env: Pendulum-v0
    run: DDPG2
    trial_resources:
        cpu: 2
    stop:
        episode_reward_mean: -160
        time_total_s: 900
    config:
        use_huber: True
        random_starts: False
        clip_rewards: False
        exploration_fraction: 0.4
        model:
            fcnet_hiddens: []
        smoothing_num_episodes: 10