ray/rllib/agents/ddpg/ddpg.py


import logging
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.utils.deprecation import deprecation_warning, \
    DEPRECATED_VALUE
from ray.rllib.utils.exploration.per_worker_ornstein_uhlenbeck_noise import \
    PerWorkerOrnsteinUhlenbeckNoise

logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks ===
    # TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
    # In addition to settings below, you can use "exploration_noise_type" and
    # "exploration_gauss_act_noise" to get IID Gaussian exploration noise
    # instead of OU exploration noise.
    # twin Q-net
    "twin_q": False,
    # delayed policy update
    "policy_delay": 1,
    # target policy smoothing
    # (this also replaces OU exploration noise with IID Gaussian exploration
    # noise, for now)
    "smooth_target_policy": False,
    # gaussian stddev of target action noise for smoothing
    "target_noise": 0.2,
    # target noise limit (bound)
    "target_noise_clip": 0.5,

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_num_episodes": 10,

    # === Model ===
    # Apply a state preprocessor with spec given by the "model" config option
    # (like other RL algorithms). This is mostly useful if you have a weird
    # observation shape, like an image. Disabled by default.
    "use_state_preprocessor": False,
    # Postprocess the policy network model output with these hidden layers. If
    # use_state_preprocessor is False, then these will be the *only* hidden
    # layers in the network.
    "actor_hiddens": [400, 300],
    # Hidden layers activation of the postprocessing stage of the policy
    # network.
    "actor_hidden_activation": "relu",
    # Postprocess the critic network model output with these hidden layers;
    # again, if use_state_preprocessor is True, then the state will be
    # preprocessed by the model specified with the "model" config option first.
    "critic_hiddens": [400, 300],
    # Hidden layers activation of the postprocessing stage of the critic.
    "critic_hidden_activation": "relu",
    # N-step Q learning.
    "n_step": 1,

    # === Exploration ===
    "exploration_config": {
        # DDPG uses OrnsteinUhlenbeck (stateful) noise to be added to NN-output
        # actions (after a possible pure random phase of n timesteps).
        "type": "OrnsteinUhlenbeckNoise",
        # For how many timesteps should we return completely random actions,
        # before we start adding (scaled) noise?
        "random_timesteps": 1000,
        # The OU base scaling factor to always apply to action-added noise.
        "ou_base_scale": 0.1,
        # The OU theta param.
        "ou_theta": 0.15,
        # The OU sigma param.
        "ou_sigma": 0.2,
        # The initial noise scaling factor.
        "initial_scale": 1.0,
        # The final noise scaling factor.
        "final_scale": 1.0,
        # Timesteps over which to anneal scale (from initial to final values).
        "scale_timesteps": 10000,
    },
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Extra configuration that disables exploration.
    "evaluation_config": {
        "explore": False
    },
    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": 50000,
    # If True, a prioritized replay buffer will be used.
    "prioritized_replay": True,
    # Alpha parameter for prioritized replay buffer.
    "prioritized_replay_alpha": 0.6,
    # Beta parameter for sampling from prioritized replay buffer.
    "prioritized_replay_beta": 0.4,
    # Time steps over which the beta parameter is annealed.
    "prioritized_replay_beta_annealing_timesteps": 20000,
    # Final value of beta.
    "final_prioritized_replay_beta": 0.4,
    # Epsilon to add to the TD errors when updating priorities.
    "prioritized_replay_eps": 1e-6,
    # Whether to LZ4-compress observations.
    "compress_observations": False,
    # If set, this fixes the ratio of timesteps replayed from the buffer
    # (and learned on) to timesteps sampled from the environment (and stored
    # in the replay buffer). Otherwise, replay proceeds at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for the critic (Q-function) optimizer.
    "critic_lr": 1e-3,
    # Learning rate for the actor (policy) optimizer.
    "actor_lr": 1e-3,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 0,
    # Update the target by \tau * policy + (1 - \tau) * target_policy.
    "tau": 0.002,
    # If True, use Huber loss instead of squared loss for the critic network.
    # Conventionally, there is no need to clip gradients when using a Huber
    # loss.
    "use_huber": False,
    # Threshold for the Huber loss.
    "huber_threshold": 1.0,
    # Weights for L2 regularization.
    "l2_reg": 1e-6,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": None,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1500,
    # Update the replay buffer with this many samples at once. Note that this
    # setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 256,
    # === Parallelism ===
    # Number of workers used for collecting samples. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # Deprecated keys.
    "parameter_noise": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable
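# A minimal sketch (not used anywhere by this module): a TD3-style
# configuration assembled from the keys documented above. The name
# TD3_LIKE_CONFIG is illustrative only.
TD3_LIKE_CONFIG = dict(
    DEFAULT_CONFIG,
    **{
        # Use a second Q-network to reduce overestimation bias.
        "twin_q": True,
        # Update the policy only every other critic update.
        "policy_delay": 2,
        # Smooth target actions with clipped Gaussian noise (this also swaps
        # the OU exploration noise for IID Gaussian noise, per the comment
        # above).
        "smooth_target_policy": True,
        "target_noise": 0.2,
        "target_noise_clip": 0.5,
    })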
def validate_config(config):
    if config["model"]["custom_model"]:
        logger.warning(
            "Setting use_state_preprocessor=True since a custom model "
            "was specified.")
        config["use_state_preprocessor"] = True

    # TODO(sven): Remove at some point.
    # Backward compatibility of noise-based exploration config.
    schedule_max_timesteps = None
    if config.get("schedule_max_timesteps", DEPRECATED_VALUE) != \
            DEPRECATED_VALUE:
        deprecation_warning("schedule_max_timesteps",
                            "exploration_config.scale_timesteps")
        schedule_max_timesteps = config["schedule_max_timesteps"]
    if config.get("exploration_final_scale", DEPRECATED_VALUE) != \
            DEPRECATED_VALUE:
        deprecation_warning("exploration_final_scale",
                            "exploration_config.final_scale")
        if isinstance(config["exploration_config"], dict):
            config["exploration_config"]["final_scale"] = \
                config.pop("exploration_final_scale")
    if config.get("exploration_fraction", DEPRECATED_VALUE) != \
            DEPRECATED_VALUE:
        assert schedule_max_timesteps is not None
        deprecation_warning("exploration_fraction",
                            "exploration_config.scale_timesteps")
        if isinstance(config["exploration_config"], dict):
            config["exploration_config"]["scale_timesteps"] = config.pop(
                "exploration_fraction") * schedule_max_timesteps
    if config.get("per_worker_exploration", DEPRECATED_VALUE) != \
            DEPRECATED_VALUE:
        deprecation_warning(
            "per_worker_exploration",
            "exploration_config.type=PerWorkerOrnsteinUhlenbeckNoise")
        if isinstance(config["exploration_config"], dict):
            config["exploration_config"]["type"] = \
                PerWorkerOrnsteinUhlenbeckNoise
    if config.get("parameter_noise", DEPRECATED_VALUE) != DEPRECATED_VALUE:
        deprecation_warning("parameter_noise", "exploration_config={"
                            "type=ParameterNoise}")

    if config["exploration_config"]["type"] == "ParameterNoise":
        if config["batch_mode"] != "complete_episodes":
            logger.warning(
                "ParameterNoise Exploration requires `batch_mode` to be "
                "'complete_episodes'. Setting batch_mode=complete_episodes.")
            config["batch_mode"] = "complete_episodes"
def get_policy_class(config):
    if config["framework"] == "torch":
        from ray.rllib.agents.ddpg.ddpg_torch_policy import DDPGTorchPolicy
        return DDPGTorchPolicy
    else:
        return DDPGTFPolicy
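# For example (sketch): the policy class is selected via the common
# "framework" config key.
#     get_policy_class({"framework": "torch"})  # -> DDPGTorchPolicy
#     get_policy_class({"framework": "tf"})     # -> DDPGTFPolicy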
DDPGTrainer = GenericOffPolicyTrainer.with_updates(
    name="DDPG",
    default_config=DEFAULT_CONFIG,
    default_policy=DDPGTFPolicy,
    get_policy_class=get_policy_class,
    validate_config=validate_config,
)
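if __name__ == "__main__":
    # Minimal usage sketch (not part of RLlib's own examples): train DDPG for
    # one iteration on the continuous-control env Pendulum-v0. Assumes `gym`
    # is installed and that starting a local Ray instance here is acceptable.
    import ray

    ray.init()
    trainer = DDPGTrainer(env="Pendulum-v0", config={"num_workers": 0})
    result = trainer.train()
    print("mean episode reward:", result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()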