"""
Soft Actor Critic (SAC)
=======================

This file defines the distributed Trainer class for the soft actor critic
algorithm.
See `sac_[tf|torch]_policy.py` for the definition of the policy loss.

Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#sac
"""

import logging
from typing import Optional, Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)

OPTIMIZER_SHARED_CONFIGS = [
    "buffer_size", "prioritized_replay", "prioritized_replay_alpha",
    "prioritized_replay_beta", "prioritized_replay_eps",
    "rollout_fragment_length", "train_batch_size", "learning_starts"
]

# yapf: disable
# __sphinx_doc_begin__

# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Use two Q-networks (instead of one) for action-value estimation.
    # Note: Each Q-network will have its own target network.
    "twin_q": True,
    # Use a state-preprocessing network (e.g. a conv2D net) before
    # concatenating the resulting (feature) vector with the action input
    # for the input to the Q-networks.
    "use_state_preprocessor": DEPRECATED_VALUE,
    # Model options for the Q network(s). These will override MODEL_DEFAULTS.
    # The `Q_model` dict is treated just like the top-level `model` dict in
    # setting up the Q-network(s) (2 if twin_q=True).
    # That means, e.g., for different observation spaces:
    # obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
    # obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
    #   -> post_fcnet
    # obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
    #   -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
    # You can also have SAC use your custom_model as Q-model(s) by simply
    # specifying the `custom_model` sub-key in the dict below (just like you
    # would do in the top-level `model` dict).
    "Q_model": {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": None,
        "custom_model": None,  # Use this to define custom Q-model(s).
        "custom_model_config": {},
    },
    # Model options for the policy function (see `Q_model` above for details).
    # The difference to `Q_model` above is that no action concat'ing is
    # performed before the post_fcnet stack.
    "policy_model": {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": None,
        "custom_model": None,  # Use this to define a custom policy model.
        "custom_model_config": {},
    },
    # Unsquash actions to the upper and lower bounds of env's action space.
    # Ignored for discrete action spaces.
    "normalize_actions": True,

    # === Learning ===
    # Disable setting done=True at end of episode. This should be set to True
    # for infinite-horizon MDPs (e.g., many continuous control problems).
    "no_done_at_end": False,
    # Update the target by \tau * policy + (1-\tau) * target_policy.
    "tau": 5e-3,
    # Initial value to use for the entropy weight alpha.
    "initial_alpha": 1.0,
    # Target entropy lower bound. If "auto", will be set to -|A| (e.g. -2.0 for
    # Discrete(2), -3.0 for Box(shape=(3,))).
    # This is the inverse of reward scale, and will be optimized automatically.
    "target_entropy": None,
    # N-step target updates. If >1, sars' tuples in trajectories will be
    # postprocessed to become (s, a, discounted sum of R, s_{t+n}) tuples.
    "n_step": 1,
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 100,

    # === Replay buffer ===
    # Size of the replay buffer (in time steps).
    "buffer_size": int(1e6),
    # If True, a prioritized replay buffer will be used.
    "prioritized_replay": False,
    "prioritized_replay_alpha": 0.6,
    "prioritized_replay_beta": 0.4,
    "prioritized_replay_eps": 1e-6,
    "prioritized_replay_beta_annealing_timesteps": 20000,
    "final_prioritized_replay_beta": 0.4,
    # Whether to LZ4 compress observations.
    "compress_observations": False,
    # If set, this will fix the ratio of timesteps replayed from the buffer
    # (and learned on) to timesteps sampled from the environment (and stored
    # in the replay buffer). Otherwise, replay proceeds at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).
    "training_intensity": None,

    # === Optimization ===
    "optimization": {
        "actor_learning_rate": 3e-4,
        "critic_learning_rate": 3e-4,
        "entropy_learning_rate": 3e-4,
    },
    # If not None, clip gradients during optimization at this value.
    "grad_clip": None,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1500,
    # Update the replay buffer with this many samples at once. Note that this
    # setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1,
    # Size of a batch sampled from the replay buffer for training.
    "train_batch_size": 256,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 0,

    # === Parallelism ===
    # Whether to use a GPU for local optimization.
    "num_gpus": 0,
    # Number of workers used to collect samples. This only makes sense to
    # increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to allocate GPUs for workers (if > 0).
    "num_gpus_per_worker": 0,
    # Whether to allocate CPUs for workers (if > 0).
    "num_cpus_per_worker": 1,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # Whether the loss should be calculated deterministically (w/o the
    # stochastic action sampling step). Only useful for continuous actions
    # and for debugging!
    "_deterministic_loss": False,
    # Use a Beta-distribution instead of a SquashedGaussian for bounded,
    # continuous action spaces (not recommended; for debugging only).
    "_use_beta_distribution": False,
})
# __sphinx_doc_end__
# yapf: enable
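
# Example (minimal sketch): overriding pieces of `DEFAULT_CONFIG`, e.g. to plug
# in a custom Q-model and to fix the training intensity. `MyQModel` is a
# placeholder name that would need to be registered with RLlib's ModelCatalog
# first, and `hidden_dim` is an arbitrary example option for that model.
#
#     config = DEFAULT_CONFIG.copy()
#     config["Q_model"]["custom_model"] = "MyQModel"
#     config["Q_model"]["custom_model_config"] = {"hidden_dim": 128}
#     config["training_intensity"] = 256  # trained steps per sampled step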


def validate_config(config: TrainerConfigDict) -> None:
    """Validates the Trainer's config dict.

    Args:
        config (TrainerConfigDict): The Trainer's config to check.

    Raises:
        ValueError: In case something is wrong with the config.
    """
    if config["num_gpus"] > 1 and config["framework"] != "torch":
        raise ValueError("`num_gpus` > 1 not yet supported for tf-SAC!")

    if config["use_state_preprocessor"] != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['use_state_preprocessor']", error=False)
        config["use_state_preprocessor"] = DEPRECATED_VALUE

    if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
        raise ValueError("`grad_clip` value must be > 0.0!")

    # `simple_optimizer` must be True (or left unset) for SAC.
    if config["simple_optimizer"] is False:
        logger.warning("`simple_optimizer` must be True (or unset) for SAC!")
        config["simple_optimizer"] = True


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Policy class picker function. Class is chosen based on DL-framework.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with SACTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":
        from ray.rllib.agents.sac.sac_torch_policy import SACTorchPolicy
        return SACTorchPolicy


# Build a child class of `Trainer` (based on the kwargs used to create the
# GenericOffPolicyTrainer class and the kwargs used in the call below), which
# uses the framework-specific Policy determined in `get_policy_class()` above.
SACTrainer = GenericOffPolicyTrainer.with_updates(
    name="SAC",
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=SACTFPolicy,
    get_policy_class=get_policy_class,
)
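

# Minimal usage sketch: train SAC on gym's `Pendulum-v0` for a few iterations.
# Assumes Ray and gym are installed; `Pendulum-v0` and the settings below are
# illustrative choices only. Guarded by `__main__` so importing this module
# stays side-effect free.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = SACTrainer(env="Pendulum-v0", config={"framework": "tf"})
    # Each `train()` call runs one training iteration and returns a results
    # dict (episode reward stats, timesteps sampled/trained so far, etc.).
    for _ in range(3):
        print(trainer.train())
    ray.shutdown()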