[RLlib] SAC, RNNSAC, and CQL TrainerConfig objects (#25059)

This commit is contained in:
Steven Morad 2022-05-22 18:58:47 +01:00 committed by GitHub
parent 44773e810b
commit 501d932449
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 662 additions and 411 deletions

View file

@ -438,7 +438,7 @@ Soft Actor Critic (SAC)
SAC architecture (same as DQN)
RLlib's soft-actor critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs.
Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``, the ``model`` field of the config will be ignored.
Note that SAC has two fields to configure for custom models: ``policy_model_config`` and ``q_model_config``; the ``model`` field of the config will be ignored.
Tuned examples (continuous actions):
`Pendulum-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/sac/pendulum-sac.yaml>`__,
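
To make the ``policy_model_config`` / ``q_model_config`` note above concrete, a minimal sketch using the new ``SACConfig`` API introduced in this commit (the layer sizes are illustrative, not tuned):

from ray.rllib.algorithms.sac import SACConfig

# Both dicts override MODEL_DEFAULTS for their respective networks;
# the top-level `model` config is ignored by SAC.
config = (
    SACConfig()
    .training(
        q_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
        policy_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
    )
    .rollouts(num_rollout_workers=0)
)
trainer = config.build(env="Pendulum-v1")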

View file

@ -106,7 +106,7 @@ This `runnable example <https://github.com/ray-project/ray/blob/master/rllib/exa
On-policy algorithms and experience postprocessing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RLlib assumes that input batches are of `postprocessed experiences <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/policy/policy.py#L376>`__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/agents/dqn/dqn_tf_policy.py#L387>`__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data.
RLlib assumes that input batches are of `postprocessed experiences <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/policy/policy.py#L376>`__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/agents/dqn/dqn_tf_policy.py#L387>`__ is only needed if ``n_step > 1`` or ``replay_buffer_config.worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data.
However, for on-policy algorithms like PPO, you'll need to pass in the extra values added during policy evaluation and postprocessing to ``batch_builder.add_values()``, e.g., ``logits``, ``vf_preds``, ``value_target``, and ``advantages`` for PPO. This is needed since the calculation of these values depends on the parameters of the *behaviour* policy, which RLlib does not have access to in the offline setting (in online training, these values are automatically added during policy evaluation).
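
As a concrete illustration of the on-policy case, a minimal sketch that records one transition together with extra postprocessed columns, assuming the ``SampleBatchBuilder`` and ``JsonWriter`` helpers used in RLlib's offline examples (the column values below are dummy placeholders; in practice they come from the behaviour policy's rollout and postprocessing):

import numpy as np

from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
from ray.rllib.offline.json_writer import JsonWriter

batch_builder = SampleBatchBuilder()
writer = JsonWriter("/tmp/offline-out")

# One dummy transition; the last three columns are the kind of extra values
# (value predictions, advantages, value targets) an on-policy algorithm such
# as PPO expects to find in its input batches.
batch_builder.add_values(
    obs=np.zeros(4, dtype=np.float32),
    actions=0,
    rewards=1.0,
    dones=True,
    new_obs=np.zeros(4, dtype=np.float32),
    vf_preds=0.0,
    advantages=0.0,
    value_targets=0.0,
)
writer.write(batch_builder.build_and_reset())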

View file

@ -13,10 +13,10 @@ cql-halfcheetahbulletenv-v0:
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

View file

@ -10,10 +10,10 @@ sac-halfcheetahbulletenv-v0:
config:
horizon: 1000
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

View file

@ -80,7 +80,7 @@ def _import_bc():
def _import_cql():
from ray.rllib.algorithms import cql
return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
return cql.CQLTrainer, cql.DEFAULT_CONFIG
def _import_ddpg():

View file

@ -849,7 +849,7 @@ class TrainerConfig:
self.evaluation_num_workers = evaluation_num_workers
if custom_evaluation_function is not None:
self.custom_evaluation_function = custom_evaluation_function
if self.always_attach_evaluation_results:
if always_attach_evaluation_results:
self.always_attach_evaluation_results = always_attach_evaluation_results
return self
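
A short sketch of the code path this fixes: with the corrected check, a user-supplied `always_attach_evaluation_results` now actually reaches the config (the other values are illustrative):

from ray.rllib.agents.trainer_config import TrainerConfig

config = TrainerConfig().evaluation(
    # Before this fix, the passed value could be ignored because the old
    # code checked `self.always_attach_evaluation_results` instead.
    always_attach_evaluation_results=True,
    evaluation_interval=2,
    evaluation_num_workers=1,
)
assert config.to_dict()["always_attach_evaluation_results"] is True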

View file

@ -1,8 +1,9 @@
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.algorithms.cql.cql import CQLTrainer, DEFAULT_CONFIG, CQLConfig
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
__all__ = [
"CQL_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"CQLTorchPolicy",
"CQLTrainer",
"CQLConfig",
]

View file

@ -1,10 +1,13 @@
import logging
import numpy as np
from typing import Type
from typing import Optional, Type
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.algorithms.sac.sac import (
SACTrainer,
SACConfig,
)
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
train_one_step,
@ -12,9 +15,12 @@ from ray.rllib.execution.train_ops import (
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
Deprecated,
)
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
@ -31,38 +37,86 @@ tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
CQL_DEFAULT_CONFIG = merge_dicts(
SAC_CONFIG, {
# You should override this to point to an offline dataset.
"input": "sampler",
# Switch off off-policy evaluation.
"input_evaluation": [],
# Number of iterations with Behavior Cloning Pretraining.
"bc_iters": 20000,
# CQL loss temperature.
"temperature": 1.0,
# Number of actions to sample for CQL loss.
"num_actions": 10,
# Whether to use the Lagrangian for Alpha Prime (in CQL loss).
"lagrangian": False,
# Lagrangian threshold.
"lagrangian_thresh": 5.0,
# Min Q weight multiplier.
"min_q_weight": 5.0,
# Reporting: As CQL is offline (no sampling steps), we need to limit
# `self.train()` reporting by the number of steps trained (not sampled).
"min_sample_timesteps_per_reporting": 0,
"min_train_timesteps_per_reporting": 100,
# Deprecated keys.
# Use `min_sample_timesteps_per_reporting` and
# `min_train_timesteps_per_reporting` instead.
"timesteps_per_iteration": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on
class CQLConfig(SACConfig):
"""Defines a configuration class from which a CQLTrainer can be built.
Example:
>>> config = CQLConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
"""
def __init__(self, trainer_class=None):
super().__init__(trainer_class=trainer_class or CQLTrainer)
# fmt: off
# __sphinx_doc_begin__
# CQL-specific config settings:
self.bc_iters = 20000
self.temperature = 1.0
self.num_actions = 10
self.lagrangian = False
self.lagrangian_thresh = 5.0
self.min_q_weight = 5.0
# Changes to Trainer's/SACConfig's default:
# .offline_data()
self.input_evaluation = []
# .reporting()
self.min_sample_timesteps_per_reporting = 0
self.min_train_timesteps_per_reporting = 100
# fmt: on
# __sphinx_doc_end__
self.timesteps_per_iteration = DEPRECATED_VALUE
def training(
self,
*,
bc_iters: Optional[int] = None,
temperature: Optional[float] = None,
num_actions: Optional[int] = None,
lagrangian: Optional[bool] = None,
lagrangian_thresh: Optional[float] = None,
min_q_weight: Optional[float] = None,
**kwargs,
) -> "CQLConfig":
"""Sets the training-related configuration.
Args:
bc_iters: Number of iterations with Behavior Cloning pretraining.
temperature: CQL loss temperature.
num_actions: Number of actions to sample for CQL loss.
lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
lagrangian_thresh: Lagrangian threshold.
min_q_weight: Min Q weight multiplier.
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if bc_iters is not None:
self.bc_iters = bc_iters
if temperature is not None:
self.temperature = temperature
if num_actions is not None:
self.num_actions = num_actions
if lagrangian is not None:
self.lagrangian = lagrangian
if lagrangian_thresh is not None:
self.lagrangian_thresh = lagrangian_thresh
if min_q_weight is not None:
self.min_q_weight = min_q_weight
return self
class CQLTrainer(SACTrainer):
@ -111,7 +165,7 @@ class CQLTrainer(SACTrainer):
@classmethod
@override(SACTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return CQL_DEFAULT_CONFIG
return CQLConfig().to_dict()
@override(SACTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@ -208,3 +262,20 @@ class CQLTrainer(SACTrainer):
# Return all collected metrics for the iteration.
return train_results
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(CQLConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG",
new="ray.rllib.algorithms.cql.cql.CQLConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
CQL_DEFAULT_CONFIG = DEFAULT_CONFIG
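
For reference, a minimal usage sketch of the new CQLConfig with an offline dataset (the JSON path is a placeholder for your own data):

from ray.rllib.algorithms.cql import CQLConfig

config = (
    CQLConfig()
    .environment(env="Pendulum-v1")
    # Placeholder path; CQL is trained from an offline dataset.
    .offline_data(input_=["/path/to/offline_data.json"])
    .training(bc_iters=20000, temperature=1.0, num_actions=10, min_q_weight=5.0)
)
trainer = config.build()
print(trainer.train())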

View file

@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):
CQLTFPolicy = build_tf_policy(
name="CQLTFPolicy",
loss_fn=cql_loss,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG,
validate_spaces=validate_spaces,
stats_fn=cql_stats,
postprocess_fn=postprocess_trajectory,

View file

@ -390,7 +390,7 @@ CQLTorchPolicy = build_policy_class(
name="CQLTorchPolicy",
framework="torch",
loss_fn=cql_loss,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG,
stats_fn=cql_stats,
postprocess_fn=postprocess_trajectory,
extra_grad_process_fn=apply_grad_clipping,

View file

@ -4,7 +4,7 @@ import os
import unittest
import ray
import ray.rllib.algorithms.cql as cql
from ray.rllib.algorithms import cql
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import (
check_compute_single_action,
@ -38,36 +38,43 @@ class TestCQL(unittest.TestCase):
data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json")
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
config = {
"env": "Pendulum-v1",
"input": [data_file],
# In the files, we use here for testing, actions have already
# been normalized.
# This is usually the case when the file was generated by another
# RLlib algorithm (e.g. PPO or SAC).
"actions_in_input_normalized": False,
"clip_actions": True,
"train_batch_size": 2000,
"twin_q": True,
"replay_buffer_config": {"learning_starts": 0},
"bc_iters": 2, # 2 BC iters, 2 CQL iters.
"rollout_fragment_length": 1,
# Switch on off-policy evaluation.
"input_evaluation": ["is"],
"always_attach_evaluation_results": True,
"evaluation_interval": 2,
"evaluation_duration": 10,
"evaluation_config": {
"input": "sampler",
},
"evaluation_parallel_to_training": False,
"evaluation_num_workers": 2,
}
config = (
cql.CQLConfig()
.environment(
env="Pendulum-v1",
)
.offline_data(
input_=[data_file],
# In the files we use here for testing, actions have already
# been normalized.
# This is usually the case when the file was generated by another
# RLlib algorithm (e.g. PPO or SAC).
actions_in_input_normalized=False,
# Switch on off-policy evaluation.
input_evaluation=["is"],
)
.training(
clip_actions=False,
train_batch_size=2000,
twin_q=True,
replay_buffer_config={"learning_starts": 0},
bc_iters=2,
)
.evaluation(
always_attach_evaluation_results=True,
evaluation_interval=2,
evaluation_duration=10,
evaluation_config={"input": "sampler"},
evaluation_parallel_to_training=False,
evaluation_num_workers=2,
)
.rollouts(rollout_fragment_length=1)
)
num_iterations = 4
# Test for tf/torch frameworks.
for fw in framework_iterator(config, with_eager_tracing=True):
trainer = cql.CQLTrainer(config=config)
trainer = config.build()
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)

View file

@ -124,9 +124,11 @@ class DDPGConfig(SimpleQConfig):
# .rollouts()
self.rollout_fragment_length = 1
self.worker_side_prioritization = False
self.compress_observations = False
# Deprecated.
self.worker_side_prioritization = DEPRECATED_VALUE
@override(TrainerConfig)
def training(
self,
@ -148,7 +150,6 @@ class DDPGConfig(SimpleQConfig):
use_huber: Optional[bool] = None,
huber_threshold: Optional[float] = None,
l2_reg: Optional[float] = None,
worker_side_prioritization: Optional[bool] = None,
training_intensity: Optional[float] = None,
**kwargs,
) -> "DDPGConfig":
@ -190,7 +191,6 @@ class DDPGConfig(SimpleQConfig):
use_huber: Conventionally, no need to clip gradients if using a huber loss
huber_threshold: Threshold of a huber loss
l2_reg: Weights for L2 regularization
worker_side_prioritization: Whether to compute priorities on workers.
training_intensity: The intensity with which to update the model
(vs collecting samples from
the env). If None, uses the "natural" value of:
@ -246,8 +246,6 @@ class DDPGConfig(SimpleQConfig):
self.huber_threshold = huber_threshold
if l2_reg is not None:
self.l2_reg = l2_reg
if worker_side_prioritization is not None:
self.worker_side_prioritization = worker_side_prioritization
if training_intensity is not None:
self.training_intensity = training_intensity
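
Since `worker_side_prioritization` is removed from `training()` here (and deprecated as a top-level key later in this commit), the flag now belongs inside `replay_buffer_config`, e.g. in a plain config dict:

# Old, now deprecated:
#   config["worker_side_prioritization"] = True
# New location inside the replay buffer config:
config = {
    "replay_buffer_config": {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "worker_side_prioritization": True,
    },
}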

View file

@ -135,7 +135,7 @@ class DQNConfig(SimpleQConfig):
self.before_learn_on_batch = None
self.training_intensity = None
# Changes to SimpleQConfig default
# Changes to SimpleQConfig's default:
self.replay_buffer_config = {
"type": "MultiAgentPrioritizedReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
@ -176,7 +176,6 @@ class DQNConfig(SimpleQConfig):
Type[MultiAgentBatch],
] = None,
training_intensity: Optional[float] = None,
worker_side_prioritization: Optional[bool] = None,
replay_buffer_config: Optional[dict] = None,
**kwargs,
) -> "DQNConfig":
@ -213,7 +212,6 @@ class DQNConfig(SimpleQConfig):
-> will make sure that replay+train op will be executed 4x as often as
rollout+insert op (4 * 250 = 1000).
See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
worker_side_prioritization: Whether to compute priorities on workers.
replay_buffer_config: Replay buffer config.
Examples:
{
@ -278,11 +276,11 @@ class DQNConfig(SimpleQConfig):
self.before_learn_on_batch = before_learn_on_batch
if training_intensity is not None:
self.training_intensity = training_intensity
if worker_side_prioritization is not None:
self.worker_side_priorizatiion = worker_side_prioritization
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config
return self
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):

View file

@ -71,7 +71,6 @@ class QMixConfig(SimpleQConfig):
self.optim_alpha = 0.99
self.optim_eps = 0.00001
self.grad_norm_clipping = 10
self.worker_side_prioritization = False
# Override some of TrainerConfig's default values with QMix-specific values.
# .training()
@ -136,6 +135,8 @@ class QMixConfig(SimpleQConfig):
# __sphinx_doc_end__
# fmt: on
self.worker_side_prioritization = DEPRECATED_VALUE
@override(SimpleQConfig)
def training(
self,
@ -148,7 +149,6 @@ class QMixConfig(SimpleQConfig):
optim_alpha: Optional[float] = None,
optim_eps: Optional[float] = None,
grad_norm_clipping: Optional[float] = None,
worker_side_prioritization: Optional[bool] = None,
**kwargs,
) -> "QMixConfig":
"""Sets the training related configuration.
@ -164,8 +164,6 @@ class QMixConfig(SimpleQConfig):
optim_eps: RMSProp epsilon.
grad_norm_clipping: If not None, clip gradients during optimization at
this value.
worker_side_prioritization: Whether to compute priorities for the replay
buffer on worker side.
Returns:
This updated TrainerConfig object.
@ -189,8 +187,6 @@ class QMixConfig(SimpleQConfig):
self.optim_eps = optim_eps
if grad_norm_clipping is not None:
self.grad_norm_clipping = grad_norm_clipping
if worker_side_prioritization is not None:
self.worker_side_prioritization = worker_side_prioritization
return self

View file

@ -1,4 +1,4 @@
from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG, SACConfig
from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy
@ -6,14 +6,16 @@ from ray.rllib.algorithms.sac.rnnsac import (
RNNSACTrainer,
DEFAULT_CONFIG as RNNSAC_DEFAULT_CONFIG,
)
from ray.rllib.algorithms.sac.rnnsac import RNNSACTorchPolicy
from ray.rllib.algorithms.sac.rnnsac import RNNSACTorchPolicy, RNNSACConfig
__all__ = [
"DEFAULT_CONFIG",
"SACTFPolicy",
"SACTorchPolicy",
"SACTrainer",
"SACConfig",
"RNNSAC_DEFAULT_CONFIG",
"RNNSACTorchPolicy",
"RNNSACTrainer",
"RNNSACConfig",
]

View file

@ -1,50 +1,78 @@
from typing import Type
from typing import Type, Optional
from ray.rllib.algorithms.sac import SACTrainer, DEFAULT_CONFIG as SAC_DEFAULT_CONFIG
from ray.rllib.algorithms.sac import (
SACTrainer,
SACConfig,
)
from ray.rllib.algorithms.sac.rnnsac_torch_policy import RNNSACTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
DEFAULT_CONFIG = SACTrainer.merge_trainer_configs(
SAC_DEFAULT_CONFIG,
{
# Batch mode (see common config)
"batch_mode": "complete_episodes",
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
# it (as initial state when running through the network for training),
# and update that initial state during training (from the internal
# state outputs of the immediately preceding sequence).
"zero_init_states": True,
"replay_buffer_config": {
# If > 0, use the `burn_in` first steps of each replay-sampled sequence
# (starting either from all 0.0-values if `zero_init_state=True` or
# from the already stored values) to calculate an even more accurate
# initial states for the actual sequence (starting after this burn-in
# window). In the burn-in case, the actual length of the sequence
# used for loss calculation is `n - burn_in` time steps
# (n=LSTMs/attention nets max_seq_len).
"replay_burn_in": 0,
# Set automatically: The number of contiguous environment steps to
# replay at once. Will be calculated via
# model->max_seq_len + burn_in.
# Do not set this to any valid value!
"replay_sequence_length": -1,
},
"burn_in": DEPRECATED_VALUE,
},
_allow_unknown_configs=True,
)
class RNNSACConfig(SACConfig):
"""Defines a configuration class from which an RNNSACTrainer can be built.
Example:
>>> config = RNNSACConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
"""
def __init__(self, trainer_class=None):
super().__init__(trainer_class=trainer_class or RNNSACTrainer)
# fmt: off
# __sphinx_doc_begin__
self.burn_in = DEPRECATED_VALUE
self.batch_mode = "complete_episodes"
self.zero_init_states = True
self.replay_buffer_config["replay_burn_in"] = 0
# Set automatically: The number of contiguous environment steps to
# replay at once. Will be calculated via
# model->max_seq_len + burn_in.
# Do not set this to any valid value!
self.replay_buffer_config["replay_sequence_length"] = -1
# fmt: on
# __sphinx_doc_end__
@override(SACConfig)
def training(
self,
*,
zero_init_states: Optional[bool] = None,
**kwargs,
) -> "RNNSACConfig":
"""Sets the training related configuration.
Args:
zero_init_states: If True, assume a zero-initialized state input (no matter
where in the episode the sequence is located).
If False, store the initial states along with each SampleBatch, use
it (as initial state when running through the network for training),
and update that initial state during training (from the internal
state outputs of the immediately preceding sequence).
Returns:
This updated TrainerConfig object.
"""
super().training(**kwargs)
if zero_init_states is not None:
self.zero_init_states = zero_init_states
return self
class RNNSACTrainer(SACTrainer):
@classmethod
@override(SACTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return RNNSACConfig().to_dict()
@override(SACTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@ -81,3 +109,19 @@ class RNNSACTrainer(SACTrainer):
@override(SACTrainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
return RNNSACTorchPolicy
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(RNNSACConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.sac.rnnsac.DEFAULT_CONFIG",
new="ray.rllib.algorithms.sac.rnnsac.RNNSACConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
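
A brief usage sketch of the new RNNSACConfig, loosely following the test further down in this commit (hyperparameter values are illustrative):

from ray.rllib.algorithms.sac import RNNSACConfig

config = (
    RNNSACConfig()
    .rollouts(num_rollout_workers=0)
    .training(
        zero_init_states=True,
        # Burn in the first 20 steps of each replay-sampled sequence.
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "replay_burn_in": 20,
        },
        lr=5e-4,
    )
)
trainer = config.build(env="CartPole-v0")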

View file

@ -45,12 +45,12 @@ def build_rnnsac_model(
num_outputs = int(np.product(obs_space.shape))
# Force-ignore any additionally provided hidden layer sizes.
# Everything should be configured using SAC's "Q_model" and "policy_model"
# settings.
# Everything should be configured using SAC's `q_model_config` and
# `policy_model_config` config settings.
policy_model_config = MODEL_DEFAULTS.copy()
policy_model_config.update(config["policy_model"])
policy_model_config.update(config["policy_model_config"])
q_model_config = MODEL_DEFAULTS.copy()
q_model_config.update(config["Q_model"])
q_model_config.update(config["q_model_config"])
default_model_cls = RNNSACTorchModel

View file

@ -1,12 +1,16 @@
import logging
from typing import Type
from typing import Type, Dict, Any, Optional, Union
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.algorithms.dqn.dqn import DQNTrainer
from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
Deprecated,
)
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.typing import TrainerConfigDict
@ -15,160 +19,255 @@ tfp = try_import_tfp()
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# === Model ===
# Use two Q-networks (instead of one) for action-value estimation.
# Note: Each Q-network will have its own target network.
"twin_q": True,
# Use a e.g. conv2D state preprocessing network before concatenating the
# resulting (feature) vector with the action input for the input to
# the Q-networks.
"use_state_preprocessor": DEPRECATED_VALUE,
# Model options for the Q network(s). These will override MODEL_DEFAULTS.
# The `Q_model` dict is treated just as the top-level `model` dict in
# setting up the Q-network(s) (2 if twin_q=True).
# That means, you can do for different observation spaces:
# obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
# obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
# -> post_fcnet
# obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
# -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
# You can also have SAC use your custom_model as Q-model(s), by simply
# specifying the `custom_model` sub-key in below dict (just like you would
# do in the top-level `model` dict.
"Q_model": {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define custom Q-model(s).
"custom_model_config": {},
},
# Model options for the policy function (see `Q_model` above for details).
# The difference to `Q_model` above is that no action concat'ing is
# performed before the post_fcnet stack.
"policy_model": {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define a custom policy model.
"custom_model_config": {},
},
# Actions are already normalized, no need to clip them further.
"clip_actions": False,
class SACConfig(TrainerConfig):
"""Defines a configuration class from which an SACTrainer can be built.
# === Learning ===
# Update the target by \tau * policy + (1-\tau) * target_policy.
"tau": 5e-3,
# Initial value to use for the entropy weight alpha.
"initial_alpha": 1.0,
# Target entropy lower bound. If "auto", will be set to -|A| (e.g. -2.0 for
# Discrete(2), -3.0 for Box(shape=(3,))).
# This is the inverse of reward scale, and will be optimized automatically.
"target_entropy": "auto",
# N-step target updates. If >1, sars' tuples in trajectories will be
# postprocessed to become sa[discounted sum of R][s t+n] tuples.
"n_step": 1,
# Minimum env sampling timesteps to accumulate within a single `train()` call. This
# value does not affect learning, only the number of times `Trainer.step_attempt()`
# is called by `Trauber.train()`. If - after one `step_attempt()`, the env sampling
# timestep count has not been reached, will perform n more `step_attempt()` calls
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 100,
Example:
>>> config = SACConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
"""
# === Replay buffer ===
"replay_buffer_config": {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
# - This is True AND restoring from a checkpoint that contains no buffer
# data.
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# Whether to LZ4 compress observations
"compress_observations": False,
def __init__(self, trainer_class=None):
super().__init__(trainer_class=trainer_class or SACTrainer)
# fmt: off
# __sphinx_doc_begin__
# SAC-specific config settings.
self.twin_q = True
self.q_model_config = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define custom Q-model(s).
"custom_model_config": {},
}
self.policy_model_config = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define a custom policy model.
"custom_model_config": {},
}
self.clip_actions = False
self.tau = 5e-3
self.initial_alpha = 1.0
self.target_entropy = "auto"
self.n_step = 1
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
# Whether to compute priorities already on the remote worker side.
"worker_side_prioritization": False,
}
self.store_buffer_in_checkpoints = False
self.training_intensity = None
self.optimization = {
"actor_learning_rate": 3e-4,
"critic_learning_rate": 3e-4,
"entropy_learning_rate": 3e-4,
}
self.grad_clip = None
self.target_network_update_freq = 0
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,
# .rollout()
self.rollout_fragment_length = 1
self.compress_observations = False
# === Optimization ===
"optimization": {
"actor_learning_rate": 3e-4,
"critic_learning_rate": 3e-4,
"entropy_learning_rate": 3e-4,
},
# If not None, clip gradients during optimization at this value.
"grad_clip": None,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 1,
# Size of a batched sampled from replay buffer for training.
"train_batch_size": 256,
# Update the target network every `target_network_update_freq` sample steps.
"target_network_update_freq": 0,
# .training()
self.train_batch_size = 256
# === Parallelism ===
# Whether to use a GPU for local optimization.
"num_gpus": 0,
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
"num_cpus_per_worker": 1,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
# .reporting()
self.min_time_s_per_reporting = 1
self.min_sample_timesteps_per_reporting = 100
# __sphinx_doc_end__
# fmt: on
# Whether the loss should be calculated deterministically (w/o the
# stochastic action sampling step). True only useful for cont. actions and
# for debugging!
"_deterministic_loss": False,
# Use a Beta-distribution instead of a SquashedGaussian for bounded,
# continuous action spaces (not recommended, for debugging only).
"_use_beta_distribution": False,
})
# __sphinx_doc_end__
# fmt: on
self._deterministic_loss = False
self._use_beta_distribution = False
self.use_state_preprocessor = DEPRECATED_VALUE
self.worker_side_prioritization = DEPRECATED_VALUE
@override(TrainerConfig)
def training(
self,
*,
twin_q: Optional[bool] = None,
q_model_config: Optional[Dict[str, Any]] = None,
policy_model_config: Optional[Dict[str, Any]] = None,
tau: Optional[float] = None,
initial_alpha: Optional[float] = None,
target_entropy: Optional[Union[str, float]] = None,
n_step: Optional[int] = None,
store_buffer_in_checkpoints: Optional[bool] = None,
replay_buffer_config: Optional[Dict[str, Any]] = None,
training_intensity: Optional[float] = None,
clip_actions: Optional[bool] = None,
grad_clip: Optional[float] = None,
optimization_config: Optional[Dict[str, Any]] = None,
target_network_update_freq: Optional[int] = None,
_deterministic_loss: Optional[bool] = None,
_use_beta_distribution: Optional[bool] = None,
**kwargs,
) -> "SACConfig":
"""Sets the training related configuration.
Args:
twin_q: Use two Q-networks (instead of one) for action-value estimation.
Note: Each Q-network will have its own target network.
q_model_config: Model configs for the Q network(s). These will override
MODEL_DEFAULTS. This is treated just as the top-level `model` dict in
setting up the Q-network(s) (2 if twin_q=True).
That means, you can do for different observation spaces:
obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
-> post_fcnet
obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
-> vision-net -> concat w/ Box(1D) and action -> post_fcnet
You can also have SAC use your custom_model as Q-model(s), by simply
specifying the `custom_model` sub-key in the dict below (just like you would
do in the top-level `model` dict).
policy_model_config: Model options for the policy function (see
`q_model_config` above for details). The difference to `q_model_config`
above is that no action concat'ing is performed before the post_fcnet
stack.
tau: Update the target by \tau * policy + (1-\tau) * target_policy.
initial_alpha: Initial value to use for the entropy weight alpha.
target_entropy: Target entropy lower bound. If "auto", will be set
to -|A| (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))).
This is the inverse of reward scale, and will be optimized
automatically.
n_step: N-step target updates. If >1, sars' tuples in trajectories will be
postprocessed to become sa[discounted sum of R][s t+n] tuples.
store_buffer_in_checkpoints: Set this to True, if you want the contents of
your buffer(s) to be stored in any saved checkpoints as well.
Warnings will be created if:
- This is True AND restoring from a checkpoint that contains no buffer
data.
- This is False AND restoring from a checkpoint that does contain
buffer data.
replay_buffer_config: Replay buffer config.
Examples:
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_batch_size": 32,
"replay_sequence_length": 1,
}
- OR -
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 50000,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"replay_sequence_length": 1,
}
- Where -
prioritized_replay_alpha: Alpha parameter controls the degree of
prioritization in the buffer. In other words, when a buffer sample has
a higher temporal-difference error, with how much more probability
should it be drawn to update the parametrized Q-network. 0.0
corresponds to uniform probability. Setting this much above 1.0 may
quickly result in a heavily peaked sampling distribution with
low entropy.
prioritized_replay_beta: Beta parameter controls the degree of
importance sampling which suppresses the influence of gradient updates
from samples that have higher probability of being sampled via alpha
parameter and the temporal-difference error.
prioritized_replay_eps: Epsilon parameter sets the baseline probability
for sampling so that when the temporal-difference error of a sample is
zero, there is still a chance of drawing the sample.
training_intensity: The intensity with which to update the model (vs
collecting samples from the env).
If None, uses "natural" values of:
`train_batch_size` / (`rollout_fragment_length` x `num_workers` x
`num_envs_per_worker`).
If not None, will make sure that the ratio between timesteps inserted
into and sampled from the buffer matches the given value.
Example:
training_intensity=1000.0
train_batch_size=250
rollout_fragment_length=1
num_workers=1 (or 0)
num_envs_per_worker=1
-> natural value = 250 / 1 = 250.0
-> will make sure that replay+train op will be executed 4x as often as
rollout+insert op (4 * 250 = 1000).
See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
clip_actions: Whether to clip actions. If actions are already normalized,
this should be set to False.
grad_clip: If not None, clip gradients during optimization at this value.
optimization_config: Config dict for optimization. Set the supported keys
`actor_learning_rate`, `critic_learning_rate`, and
`entropy_learning_rate` in here.
target_network_update_freq: Update the target network every
`target_network_update_freq` steps.
_deterministic_loss: Whether the loss should be calculated deterministically
(w/o the stochastic action sampling step). Only useful for
continuous actions and for debugging.
_use_beta_distribution: Use a Beta-distribution instead of a
`SquashedGaussian` for bounded, continuous action spaces (not
recommended; for debugging only).
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if twin_q is not None:
self.twin_q = twin_q
if q_model_config is not None:
self.q_model_config = q_model_config
if policy_model_config is not None:
self.policy_model_config = policy_model_config
if tau is not None:
self.tau = tau
if initial_alpha is not None:
self.initial_alpha = initial_alpha
if target_entropy is not None:
self.target_entropy = target_entropy
if n_step is not None:
self.n_step = n_step
if store_buffer_in_checkpoints is not None:
self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config
if training_intensity is not None:
self.training_intensity = training_intensity
if clip_actions is not None:
self.clip_actions = clip_actions
if grad_clip is not None:
self.grad_clip = grad_clip
if optimization_config is not None:
self.optimization = optimization_config
if target_network_update_freq is not None:
self.target_network_update_freq = target_network_update_freq
if _deterministic_loss is not None:
self._deterministic_loss = _deterministic_loss
if _use_beta_distribution is not None:
self._use_beta_distribution = _use_beta_distribution
return self
class SACTrainer(DQNTrainer):
@ -183,13 +282,13 @@ class SACTrainer(DQNTrainer):
"""
def __init__(self, *args, **kwargs):
self._allow_unknown_subkeys += ["policy_model", "Q_model"]
self._allow_unknown_subkeys += ["policy_model_config", "q_model_config"]
super().__init__(*args, **kwargs)
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return SACConfig().to_dict()
@override(DQNTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@ -200,6 +299,22 @@ class SACTrainer(DQNTrainer):
deprecation_warning(old="config['use_state_preprocessor']", error=False)
config["use_state_preprocessor"] = DEPRECATED_VALUE
if config.get("policy_model", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config['policy_model']",
new="config['policy_model_config']",
error=False,
)
config["policy_model_config"] = config["policy_model"]
if config.get("Q_model", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config['Q_model']",
new="config['q_model_config']",
error=False,
)
config["q_model_config"] = config["Q_model"]
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")
@ -220,3 +335,20 @@ class SACTrainer(DQNTrainer):
return SACTorchPolicy
else:
return SACTFPolicy
# Deprecated: Use ray.rllib.algorithms.sac.SACConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(SACConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.sac.sac.DEFAULT_CONFIG",
new="ray.rllib.algorithms.sac.sac.SACConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
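
A compact usage sketch of the SACConfig.training() API defined above (values illustrative):

from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .training(
        twin_q=True,
        tau=5e-3,
        initial_alpha=1.0,
        target_entropy="auto",
        n_step=1,
        train_batch_size=256,
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 50000,
            "prioritized_replay_alpha": 0.6,
        },
    )
    .rollouts(num_rollout_workers=0, rollout_fragment_length=1)
)
trainer = config.build(env="Pendulum-v1")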

View file

@ -19,9 +19,9 @@ class SACTFModel(TFModelV2):
To customize, do one of the following:
- sub-class SACTFModel and override one or more of its methods.
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
- Use SAC's `q_model_config` and `policy_model_config` keys to tweak the default model
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
- Use SAC's `q_model_config->custom_model` and `policy_model_config->custom_model` keys
to specify your own custom Q-model(s) and policy-models, which will be
created within this SACTFModel (see `build_policy_model` and
`build_q_model`).
@ -160,7 +160,7 @@ class SACTFModel(TFModelV2):
Override this method in a sub-class of SACTFModel to implement your
own Q-nets. Alternatively, simply set `custom_model` within the
top level SAC `Q_model` config key to make this default implementation
top level SAC `q_model_config` config key to make this default implementation
of `build_q_model` use your custom Q-nets.
Returns:
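
To illustrate the `custom_model` sub-key described in this docstring, a hedged sketch that registers a model class under a name and points SAC's Q-networks at it (here RLlib's built-in fully connected net stands in for a real custom model; "my_q_model" is an arbitrary registration name):

from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork

# Any ModelV2 subclass can be registered under a custom name.
ModelCatalog.register_custom_model("my_q_model", FullyConnectedNetwork)

config = SACConfig().training(
    q_model_config={"custom_model": "my_q_model", "fcnet_hiddens": [64, 64]},
)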

View file

@ -72,12 +72,12 @@ def build_sac_model(
`policy.target_model`.
"""
# Force-ignore any additionally provided hidden layer sizes.
# Everything should be configured using SAC's "Q_model" and "policy_model"
# settings.
# Everything should be configured using SAC's `q_model_config` and
# `policy_model_config` config settings.
policy_model_config = copy.deepcopy(MODEL_DEFAULTS)
policy_model_config.update(config["policy_model"])
policy_model_config.update(config["policy_model_config"])
q_model_config = copy.deepcopy(MODEL_DEFAULTS)
q_model_config.update(config["Q_model"])
q_model_config.update(config["q_model_config"])
default_model_cls = SACTorchModel if config["framework"] == "torch" else SACTFModel

View file

@ -19,9 +19,9 @@ class SACTorchModel(TorchModelV2, nn.Module):
To customize, do one of the following:
- sub-class SACTorchModel and override one or more of its methods.
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
- Use SAC's `q_model_config` and `policy_model_config` keys to tweak the default model
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
- Use SAC's `q_model_config->custom_model` and `policy_model_config->custom_model` keys
to specify your own custom Q-model(s) and policy-models, which will be
created within this SACTorchModel (see `build_policy_model` and
`build_q_model`).
@ -168,7 +168,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
Override this method in a sub-class of SACTFModel to implement your
own Q-nets. Alternatively, simply set `custom_model` within the
top level SAC `Q_model` config key to make this default implementation
top level SAC `q_model_config` config key to make this default implementation
of `build_q_model` use your custom Q-nets.
Returns:

View file

@ -1,7 +1,7 @@
import unittest
import ray
import ray.rllib.algorithms.sac as sac
from ray.rllib.algorithms import sac
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, framework_iterator
@ -20,42 +20,39 @@ class TestRNNSAC(unittest.TestCase):
def test_rnnsac_compilation(self):
"""Test whether a R2D2Trainer can be built on all frameworks."""
config = sac.RNNSAC_DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
# Wrap with an LSTM and use a very simple base-model.
config["model"] = {
"max_seq_len": 20,
}
config["policy_model"] = {
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [10],
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
}
config["Q_model"] = {
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [10],
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
}
# Test with MultiAgentPrioritizedReplayBuffer
config["replay_buffer_config"] = {
"type": "MultiAgentPrioritizedReplayBuffer",
"replay_burn_in": 20,
"zero_init_states": True,
}
config["lr"] = 5e-4
config = (
sac.RNNSACConfig()
.rollouts(num_rollout_workers=0)
.training(
# Wrap with an LSTM and use a very simple base-model.
model={"max_seq_len": 20},
policy_model_config={
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [10],
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
},
q_model_config={
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [10],
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
},
replay_buffer_config={
"type": "MultiAgentPrioritizedReplayBuffer",
"replay_burn_in": 20,
"zero_init_states": True,
},
lr=5e-4,
)
)
num_iterations = 1
# Test building an RNNSAC agent in all frameworks.
for _ in framework_iterator(config, frameworks="torch"):
trainer = sac.RNNSACTrainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
for i in range(num_iterations):
results = trainer.train()
print(results)

View file

@ -6,7 +6,7 @@ import re
import unittest
import ray
import ray.rllib.algorithms.sac as sac
from ray.rllib.algorithms import sac
from ray.rllib.algorithms.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
from ray.rllib.algorithms.sac.sac_torch_policy import actor_critic_loss as loss_torch
from ray.rllib.examples.env.random_env import RandomEnv
@ -74,20 +74,17 @@ class TestSAC(unittest.TestCase):
def test_sac_compilation(self):
"""Tests whether an SACTrainer can be built with all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
config["num_workers"] = 0 # Run locally.
config["n_step"] = 3
config["twin_q"] = True
config["replay_buffer_config"]["learning_starts"] = 0
config["rollout_fragment_length"] = 10
config["train_batch_size"] = 10
# If we use default buffer size (1e6), the buffer will take up
# 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
# available system memory (8.34816 GB).
config["replay_buffer_config"]["capacity"] = 40000
# Test with saved replay buffer.
config["store_buffer_in_checkpoints"] = True
config = (
sac.SACConfig()
.training(
n_step=3,
twin_q=True,
replay_buffer_config={"learning_starts": 0, "capacity": 40000},
store_buffer_in_checkpoints=True,
train_batch_size=10,
)
.rollouts(num_rollout_workers=0, rollout_fragment_length=10)
)
num_iterations = 1
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
@ -134,12 +131,12 @@ class TestSAC(unittest.TestCase):
print("Env={}".format(env))
# Test making the Q-model a custom one for CartPole, otherwise,
# use the default model.
config["Q_model"]["custom_model"] = (
config.q_model_config["custom_model"] = (
"batch_norm{}".format("_torch" if fw == "torch" else "")
if env == "CartPole-v0"
else None
)
trainer = sac.SACTrainer(config=config, env=env)
trainer = config.build(env=env)
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
@ -167,22 +164,25 @@ class TestSAC(unittest.TestCase):
def test_sac_loss_function(self):
"""Tests SAC loss function results across all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
# Run locally.
config["seed"] = 42
config["num_workers"] = 0
config["replay_buffer_config"]["learning_starts"] = 0
config["twin_q"] = False
config["gamma"] = 0.99
# Switch on deterministic loss so we can compare the loss values.
config["_deterministic_loss"] = True
# Use very simple nets.
config["Q_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_hiddens"] = [10]
# Make sure, timing differences do not affect trainer.train().
config["min_time_s_per_reporting"] = 0
# Test SAC with Simplex action space.
config["env_config"] = {"simplex_actions": True}
config = (
sac.SACConfig()
.training(
twin_q=False,
gamma=0.99,
_deterministic_loss=True,
q_model_config={"fcnet_hiddens": [10]},
policy_model_config={"fcnet_hiddens": [10]},
replay_buffer_config={"learning_starts": 0},
)
.rollouts(num_rollout_workers=0)
.reporting(
min_time_s_per_reporting=0,
)
.environment(
env_config={"simplex_actions": True},
)
.debugging(seed=42)
)
map_ = {
# Action net.
@ -247,7 +247,7 @@ class TestSAC(unittest.TestCase):
config, frameworks=("tf", "torch"), session=True
):
# Generate Trainer and get its default Policy object.
trainer = sac.SACTrainer(config=config, env=env)
trainer = config.build(env=env)
policy = trainer.get_policy()
p_sess = None
if sess:
@ -287,7 +287,7 @@ class TestSAC(unittest.TestCase):
sorted(weights_dict.keys()),
log_alpha,
fw,
gamma=config["gamma"],
gamma=config.gamma,
sess=sess,
)
@ -520,19 +520,22 @@ class TestSAC(unittest.TestCase):
return dict_samples[self.steps], 1, self.steps >= 5, {}
tune.register_env("nested", lambda _: NestedDictEnv())
config = sac.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["replay_buffer_config"]["learning_starts"] = 0
config["rollout_fragment_length"] = 5
config["train_batch_size"] = 5
config["replay_buffer_config"]["capacity"] = 10
# Disable preprocessors.
config["_disable_preprocessor_api"] = True
config = (
sac.SACConfig()
.training(
replay_buffer_config={"learning_starts": 0, "capacity": 10},
train_batch_size=5,
)
.rollouts(
num_rollout_workers=0,
rollout_fragment_length=5,
)
.experimental(_disable_preprocessor_api=True)
)
num_iterations = 1
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = sac.SACTrainer(env="nested", config=config)
trainer = config.build(env="nested")
for _ in range(num_iterations):
results = trainer.train()
check_train_results(results)

View file

@ -139,7 +139,6 @@ class SlateQConfig(TrainerConfig):
rmsprop_epsilon: Optional[float] = None,
grad_clip: Optional[float] = None,
n_step: Optional[int] = None,
worker_side_prioritization: Optional[bool] = None,
**kwargs,
) -> "SlateQConfig":
"""Sets the training related configuration.
@ -172,8 +171,6 @@ class SlateQConfig(TrainerConfig):
rmsprop_epsilon: RMSProp epsilon hyperparameter.
grad_clip: If not None, clip gradients during optimization at this value.
n_step: N-step parameter for Q-learning.
worker_side_prioritization: Whether to compute priorities for the replay
buffer on the workers.
Returns:
This updated TrainerConfig object.
@ -205,8 +202,6 @@ class SlateQConfig(TrainerConfig):
self.grad_clip = grad_clip
if n_step is not None:
self.n_step = n_step
if worker_side_prioritization is not None:
self.worker_side_prioritization = worker_side_prioritization
return self

View file

@ -29,7 +29,7 @@ if __name__ == "__main__":
# See rllib/tuned_examples/cql/pendulum-cql.yaml for comparison.
config = cql.CQL_DEFAULT_CONFIG.copy()
config = cql.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["horizon"] = 200
config["soft_horizon"] = True
@ -45,11 +45,11 @@ if __name__ == "__main__":
config["replay_buffer_config"]["capacity"] = int(1e6)
config["tau"] = 0.005
config["target_entropy"] = "auto"
config["Q_model"] = {
config["q_model_config"] = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
}
config["policy_model"] = {
config["policy_model_config"] = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
}

View file

@ -55,14 +55,14 @@ config = {
"model": {
"max_seq_len": 20,
},
"policy_model": {
"policy_model_config": {
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [64, 64],
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
},
"Q_model": {
"q_model_config": {
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [64, 64],

View file

@ -177,8 +177,8 @@ class TestComputeLogLikelihood(unittest.TestCase):
"""Tests SAC's (cont. actions) compute_log_likelihoods method."""
config = sac.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["policy_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_activation"] = "linear"
config["policy_model_config"]["fcnet_hiddens"] = [10]
config["policy_model_config"]["fcnet_activation"] = "linear"
prev_a = np.array([0.0])
# SAC cont uses a squashed normal distribution. Implement its logp
@ -210,8 +210,8 @@ class TestComputeLogLikelihood(unittest.TestCase):
"""Tests SAC's (discrete actions) compute_log_likelihoods method."""
config = sac.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["policy_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_activation"] = "linear"
config["policy_model_config"]["fcnet_hiddens"] = [10]
config["policy_model_config"]["fcnet_activation"] = "linear"
prev_a = np.array(0)
do_test_log_likelihood(sac.SACTrainer, config, prev_a)

View file

@ -15,10 +15,10 @@ halfcheetah_bc:
framework: torch
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

View file

@ -17,10 +17,10 @@ halfcheetah_cql:
framework: tf
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

View file

@ -15,10 +15,10 @@ hopper_bc:
framework: torch
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

View file

@ -15,10 +15,10 @@ hopper_cql:
framework: torch
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

View file

@ -31,7 +31,7 @@ ddpg-hopperbulletenv-v0:
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: False
worker_side_prioritization: false
learning_starts: 500
clip_rewards: False
actor_lr: 0.001

View file

@ -37,7 +37,7 @@ mountaincarcontinuous-ddpg:
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: False
worker_side_prioritization: false
clip_rewards: False
# === Optimization ===

View file

@ -36,7 +36,7 @@ pendulum-ddpg:
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
worker_side_prioritization: False
worker_side_prioritization: false
clip_rewards: False
# === Optimization ===

View file

@ -14,10 +14,10 @@ atari-sac-tf-and-torch:
framework:
grid_search: [tf, torch]
gamma: 0.99
Q_model:
q_model_config:
hidden_activation: relu
hidden_layer_sizes: [512]
policy_model:
policy_model_config:
hidden_activation: relu
hidden_layer_sizes: [512]
# Do hard syncs.

View file

@ -8,10 +8,10 @@ halfcheetah-pybullet-sac:
framework: tf
horizon: 1000
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

View file

@ -9,10 +9,10 @@ halfcheetah_sac:
framework: tf
horizon: 1000
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

View file

@ -11,10 +11,10 @@ mspacman-sac-tf:
# Works for both torch and tf.
framework: tf
gamma: 0.99
Q_model:
q_model_config:
fcnet_hiddens: [512]
fcnet_activation: relu
policy_model:
policy_model_config:
fcnet_hiddens: [512]
fcnet_activation: relu
# Do hard syncs.

View file

@ -10,10 +10,10 @@ pendulum-sac-fake-gpus:
framework: tf
horizon: 200
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [ 256, 256 ]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [ 256, 256 ]
tau: 0.005

View file

@ -12,10 +12,10 @@ pendulum-sac:
framework: tf
horizon: 200
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

View file

@ -19,10 +19,10 @@ transformed-actions-pendulum-sac-dummy-torch:
horizon: 200
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

View file

@ -138,6 +138,13 @@ def validate_buffer_config(config: dict):
if config.get("replay_buffer_config", None) is None:
config["replay_buffer_config"] = {}
if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config['worker_side_prioritization']",
new="config['replay_buffer_config']['worker_side_prioritization']",
error=True,
)
prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
if prioritized_replay != DEPRECATED_VALUE:
deprecation_warning(