diff --git a/doc/source/rllib/rllib-algorithms.rst b/doc/source/rllib/rllib-algorithms.rst index 031df9ee8..89b18e4cf 100644 --- a/doc/source/rllib/rllib-algorithms.rst +++ b/doc/source/rllib/rllib-algorithms.rst @@ -438,7 +438,7 @@ Soft Actor Critic (SAC) SAC architecture (same as DQN) RLlib's soft-actor critic implementation is ported from the `official SAC repo `__ to better integrate with RLlib APIs. -Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``, the ``model`` field of the config will be ignored. +Note that SAC has two fields to configure for custom models: ``policy_model_config`` and ``q_model_config``; the ``model`` field of the config will be ignored. Tuned examples (continuous actions): `Pendulum-v1 `__, diff --git a/doc/source/rllib/rllib-offline.rst b/doc/source/rllib/rllib-offline.rst index ebc1fbe55..b9fe003ae 100644 --- a/doc/source/rllib/rllib-offline.rst +++ b/doc/source/rllib/rllib-offline.rst @@ -106,7 +106,7 @@ This `runnable example `__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing `__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data. +RLlib assumes that input batches are of `postprocessed experiences `__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing `__ is only needed if ``n_step > 1`` or ``replay_buffer_config.worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data. However, for on-policy algorithms like PPO, you'll need to pass in the extra values added during policy evaluation and postprocessing to ``batch_builder.add_values()``, e.g., ``logits``, ``vf_preds``, ``value_target``, and ``advantages`` for PPO. This is needed since the calculation of these values depends on the parameters of the *behaviour* policy, which RLlib does not have access to in the offline setting (in online training, these values are automatically added during policy evaluation).
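The doc changes above reference the renamed SAC model keys; they map onto the new `SACConfig` class introduced further down in this patch. As a quick orientation, here is a minimal sketch of configuring custom Q- and policy-networks through `q_model_config` and `policy_model_config` instead of the old `Q_model`/`policy_model` keys (the `Pendulum-v1` environment and the layer sizes are illustrative only, not part of this patch):

    from ray.rllib.algorithms.sac import SACConfig

    config = (
        SACConfig()
        .environment(env="Pendulum-v1")  # Illustrative env choice.
        .rollouts(num_rollout_workers=0)
        .training(
            # Replaces the old top-level `Q_model` key.
            q_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
            # Replaces the old top-level `policy_model` key.
            policy_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
        )
    )
    trainer = config.build()
    results = trainer.train()

Old-style dict configs that still pass `Q_model`/`policy_model` keep working for now: `SACTrainer.validate_config()` (see the sac.py changes below) copies them over to the new keys and emits a deprecation warning.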
diff --git a/release/rllib_tests/learning_tests/yaml_files/cql-halfcheetahbulletenv-v0.yaml b/release/rllib_tests/learning_tests/yaml_files/cql-halfcheetahbulletenv-v0.yaml index 2eecc82eb..af2275bd5 100644 --- a/release/rllib_tests/learning_tests/yaml_files/cql-halfcheetahbulletenv-v0.yaml +++ b/release/rllib_tests/learning_tests/yaml_files/cql-halfcheetahbulletenv-v0.yaml @@ -13,10 +13,10 @@ cql-halfcheetahbulletenv-v0: soft_horizon: False horizon: 1000 - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] tau: 0.005 diff --git a/release/rllib_tests/learning_tests/yaml_files/sac-halfcheetahbulletenv-v0.yaml b/release/rllib_tests/learning_tests/yaml_files/sac-halfcheetahbulletenv-v0.yaml index f7df0b8d7..6c821fe5c 100644 --- a/release/rllib_tests/learning_tests/yaml_files/sac-halfcheetahbulletenv-v0.yaml +++ b/release/rllib_tests/learning_tests/yaml_files/sac-halfcheetahbulletenv-v0.yaml @@ -10,10 +10,10 @@ sac-halfcheetahbulletenv-v0: config: horizon: 1000 soft_horizon: false - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256] tau: 0.005 diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index 61aefccdc..81cadb9d8 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -80,7 +80,7 @@ def _import_bc(): def _import_cql(): from ray.rllib.algorithms import cql - return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG + return cql.CQLTrainer, cql.DEFAULT_CONFIG def _import_ddpg(): diff --git a/rllib/agents/trainer_config.py b/rllib/agents/trainer_config.py index cb3c0f950..bf8f983d4 100644 --- a/rllib/agents/trainer_config.py +++ b/rllib/agents/trainer_config.py @@ -849,7 +849,7 @@ class TrainerConfig: self.evaluation_num_workers = evaluation_num_workers if custom_evaluation_function is not None: self.custom_evaluation_function = custom_evaluation_function - if self.always_attach_evaluation_results: + if always_attach_evaluation_results: self.always_attach_evaluation_results = always_attach_evaluation_results return self diff --git a/rllib/algorithms/cql/__init__.py b/rllib/algorithms/cql/__init__.py index 56a90d192..b99b27031 100644 --- a/rllib/algorithms/cql/__init__.py +++ b/rllib/algorithms/cql/__init__.py @@ -1,8 +1,9 @@ -from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG +from ray.rllib.algorithms.cql.cql import CQLTrainer, DEFAULT_CONFIG, CQLConfig from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy __all__ = [ - "CQL_DEFAULT_CONFIG", + "DEFAULT_CONFIG", "CQLTorchPolicy", "CQLTrainer", + "CQLConfig", ] diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 556089f75..827854bec 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -1,10 +1,13 @@ import logging import numpy as np -from typing import Type +from typing import Optional, Type from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy -from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG +from ray.rllib.algorithms.sac.sac import ( + SACTrainer, + SACConfig, +) from ray.rllib.execution.train_ops import ( multi_gpu_train_one_step, train_one_step, @@ -12,9 +15,12 @@ from ray.rllib.execution.train_ops import ( from ray.rllib.offline.shuffled_input import ShuffledInput from ray.rllib.policy.policy 
import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, + Deprecated, +) from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, @@ -31,38 +37,86 @@ tf1, tf, tfv = try_import_tf() tfp = try_import_tfp() logger = logging.getLogger(__name__) -# fmt: off -# __sphinx_doc_begin__ -CQL_DEFAULT_CONFIG = merge_dicts( - SAC_CONFIG, { - # You should override this to point to an offline dataset. - "input": "sampler", - # Switch off off-policy evaluation. - "input_evaluation": [], - # Number of iterations with Behavior Cloning Pretraining. - "bc_iters": 20000, - # CQL loss temperature. - "temperature": 1.0, - # Number of actions to sample for CQL loss. - "num_actions": 10, - # Whether to use the Lagrangian for Alpha Prime (in CQL loss). - "lagrangian": False, - # Lagrangian threshold. - "lagrangian_thresh": 5.0, - # Min Q weight multiplier. - "min_q_weight": 5.0, - # Reporting: As CQL is offline (no sampling steps), we need to limit - # `self.train()` reporting by the number of steps trained (not sampled). - "min_sample_timesteps_per_reporting": 0, - "min_train_timesteps_per_reporting": 100, - # Deprecated keys. - # Use `min_sample_timesteps_per_reporting` and - # `min_train_timesteps_per_reporting` instead. - "timesteps_per_iteration": DEPRECATED_VALUE, - }) -# __sphinx_doc_end__ -# fmt: on +class CQLConfig(SACConfig): + """Defines a configuration class from which a CQLTrainer can be built. + + Example: + >>> config = CQLConfig().training(gamma=0.9, lr=0.01)\ + ... .resources(num_gpus=0)\ + ... .rollouts(num_rollout_workers=4) + >>> print(config.to_dict()) + >>> # Build a Trainer object from the config and run 1 training iteration. + >>> trainer = config.build(env="CartPole-v1") + >>> trainer.train() + """ + + def __init__(self, trainer_class=None): + super().__init__(trainer_class=trainer_class or CQLTrainer) + + # fmt: off + # __sphinx_doc_begin__ + # CQL-specific config settings: + self.bc_iters = 20000 + self.temperature = 1.0 + self.num_actions = 10 + self.lagrangian = False + self.lagrangian_thresh = 5.0 + self.min_q_weight = 5.0 + + # Changes to Trainer's/SACConfig's default: + # .offline_data() + self.input_evaluation = [] + + # .reporting() + self.min_sample_timesteps_per_reporting = 0 + self.min_train_timesteps_per_reporting = 100 + # fmt: on + # __sphinx_doc_end__ + + self.timesteps_per_iteration = DEPRECATED_VALUE + + def training( + self, + *, + bc_iters: Optional[int] = None, + temperature: Optional[float] = None, + num_actions: Optional[int] = None, + lagrangian: Optional[bool] = None, + lagrangian_thresh: Optional[float] = None, + min_q_weight: Optional[float] = None, + **kwargs, + ) -> "CQLConfig": + """Sets the training-related configuration. + + Args: + bc_iters: Number of iterations with Behavior Cloning pretraining. + temperature: CQL loss temperature. + num_actions: Number of actions to sample for CQL loss. + lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss). + lagrangian_thresh: Lagrangian threshold. + min_q_weight: Min Q weight multiplier. + + Returns: + This updated TrainerConfig object. + """ + # Pass kwargs onto super's `training()` method.
+ super().training(**kwargs) + + if bc_iters is not None: + self.bc_iters = bc_iters + if temperature is not None: + self.temperature = temperature + if num_actions is not None: + self.num_actions = num_actions + if lagrangian is not None: + self.lagrangian = lagrangian + if lagrangian_thresh is not None: + self.lagrangian_thresh = lagrangian_thresh + if min_q_weight is not None: + self.min_q_weight = min_q_weight + + return self class CQLTrainer(SACTrainer): @@ -111,7 +165,7 @@ class CQLTrainer(SACTrainer): @classmethod @override(SACTrainer) def get_default_config(cls) -> TrainerConfigDict: - return CQL_DEFAULT_CONFIG + return CQLConfig().to_dict() @override(SACTrainer) def validate_config(self, config: TrainerConfigDict) -> None: @@ -208,3 +262,20 @@ class CQLTrainer(SACTrainer): # Return all collected metrics for the iteration. return train_results + + +class _deprecated_default_config(dict): + def __init__(self): + super().__init__(CQLConfig().to_dict()) + + @Deprecated( + old="ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG", + new="ray.rllib.algorithms.cql.cql.CQLConfig(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + +DEFAULT_CONFIG = _deprecated_default_config() +CQL_DEFAULT_CONFIG = DEFAULT_CONFIG diff --git a/rllib/algorithms/cql/cql_tf_policy.py b/rllib/algorithms/cql/cql_tf_policy.py index 8a51927c3..4fc617a6c 100644 --- a/rllib/algorithms/cql/cql_tf_policy.py +++ b/rllib/algorithms/cql/cql_tf_policy.py @@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars): CQLTFPolicy = build_tf_policy( name="CQLTFPolicy", loss_fn=cql_loss, - get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG, + get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG, validate_spaces=validate_spaces, stats_fn=cql_stats, postprocess_fn=postprocess_trajectory, diff --git a/rllib/algorithms/cql/cql_torch_policy.py b/rllib/algorithms/cql/cql_torch_policy.py index a7d26eb2c..9e44f92eb 100644 --- a/rllib/algorithms/cql/cql_torch_policy.py +++ b/rllib/algorithms/cql/cql_torch_policy.py @@ -390,7 +390,7 @@ CQLTorchPolicy = build_policy_class( name="CQLTorchPolicy", framework="torch", loss_fn=cql_loss, - get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG, + get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG, stats_fn=cql_stats, postprocess_fn=postprocess_trajectory, extra_grad_process_fn=apply_grad_clipping, diff --git a/rllib/algorithms/cql/tests/test_cql.py b/rllib/algorithms/cql/tests/test_cql.py index dfbaac4cd..aed354d06 100644 --- a/rllib/algorithms/cql/tests/test_cql.py +++ b/rllib/algorithms/cql/tests/test_cql.py @@ -4,7 +4,7 @@ import os import unittest import ray -import ray.rllib.algorithms.cql as cql +from ray.rllib.algorithms import cql from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import ( check_compute_single_action, @@ -38,36 +38,43 @@ class TestCQL(unittest.TestCase): data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json") print("data_file={} exists={}".format(data_file, os.path.isfile(data_file))) - config = { - "env": "Pendulum-v1", - "input": [data_file], - # In the files, we use here for testing, actions have already - # been normalized. - # This is usually the case when the file was generated by another - # RLlib algorithm (e.g. PPO or SAC). 
- "actions_in_input_normalized": False, - "clip_actions": True, - "train_batch_size": 2000, - "twin_q": True, - "replay_buffer_config": {"learning_starts": 0}, - "bc_iters": 2, # 2 BC iters, 2 CQL iters. - "rollout_fragment_length": 1, - # Switch on off-policy evaluation. - "input_evaluation": ["is"], - "always_attach_evaluation_results": True, - "evaluation_interval": 2, - "evaluation_duration": 10, - "evaluation_config": { - "input": "sampler", - }, - "evaluation_parallel_to_training": False, - "evaluation_num_workers": 2, - } + config = ( + cql.CQLConfig() + .environment( + env="Pendulum-v1", + ) + .offline_data( + input_=[data_file], + # In the files, we use here for testing, actions have already + # been normalized. + # This is usually the case when the file was generated by another + # RLlib algorithm (e.g. PPO or SAC). + actions_in_input_normalized=False, + # Switch on off-policy evaluation. + input_evaluation=["is"], + ) + .training( + clip_actions=False, + train_batch_size=2000, + twin_q=True, + replay_buffer_config={"learning_starts": 0}, + bc_iters=2, + ) + .evaluation( + always_attach_evaluation_results=True, + evaluation_interval=2, + evaluation_duration=10, + evaluation_config={"input": "sampler"}, + evaluation_parallel_to_training=False, + evaluation_num_workers=2, + ) + .rollouts(rollout_fragment_length=1) + ) num_iterations = 4 # Test for tf/torch frameworks. for fw in framework_iterator(config, with_eager_tracing=True): - trainer = cql.CQLTrainer(config=config) + trainer = config.build() for i in range(num_iterations): results = trainer.train() check_train_results(results) diff --git a/rllib/algorithms/ddpg/ddpg.py b/rllib/algorithms/ddpg/ddpg.py index 3456848bd..caf55902f 100644 --- a/rllib/algorithms/ddpg/ddpg.py +++ b/rllib/algorithms/ddpg/ddpg.py @@ -124,9 +124,11 @@ class DDPGConfig(SimpleQConfig): # .rollouts() self.rollout_fragment_length = 1 - self.worker_side_prioritization = False self.compress_observations = False + # Deprecated. + self.worker_side_prioritization = DEPRECATED_VALUE + @override(TrainerConfig) def training( self, @@ -148,7 +150,6 @@ class DDPGConfig(SimpleQConfig): use_huber: Optional[bool] = None, huber_threshold: Optional[float] = None, l2_reg: Optional[float] = None, - worker_side_prioritization: Optional[bool] = None, training_intensity: Optional[float] = None, **kwargs, ) -> "DDPGConfig": @@ -190,7 +191,6 @@ class DDPGConfig(SimpleQConfig): use_huber: Conventionally, no need to clip gradients if using a huber loss huber_threshold: Threshold of a huber loss l2_reg: Weights for L2 regularization - worker_side_prioritization: Whether to compute priorities on workers. training_intensity: The intensity with which to update the model (vs collecting samples from the env). 
If None, uses the "natural" value of: @@ -246,8 +246,6 @@ class DDPGConfig(SimpleQConfig): self.huber_threshold = huber_threshold if l2_reg is not None: self.l2_reg = l2_reg - if worker_side_prioritization is not None: - self.worker_side_prioritization = worker_side_prioritization if training_intensity is not None: self.training_intensity = training_intensity diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 8091ce267..1cb895337 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -135,7 +135,7 @@ class DQNConfig(SimpleQConfig): self.before_learn_on_batch = None self.training_intensity = None - # Changes to SimpleQConfig default + # Changes to SimpleQConfig's default: self.replay_buffer_config = { "type": "MultiAgentPrioritizedReplayBuffer", # Specify prioritized replay by supplying a buffer type that supports @@ -176,7 +176,6 @@ class DQNConfig(SimpleQConfig): Type[MultiAgentBatch], ] = None, training_intensity: Optional[float] = None, - worker_side_prioritization: Optional[bool] = None, replay_buffer_config: Optional[dict] = None, **kwargs, ) -> "DQNConfig": @@ -213,7 +212,6 @@ class DQNConfig(SimpleQConfig): -> will make sure that replay+train op will be executed 4x asoften as rollout+insert op (4 * 250 = 1000). See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details. - worker_side_prioritization: Whether to compute priorities on workers. replay_buffer_config: Replay buffer config. Examples: { @@ -278,11 +276,11 @@ class DQNConfig(SimpleQConfig): self.before_learn_on_batch = before_learn_on_batch if training_intensity is not None: self.training_intensity = training_intensity - if worker_side_prioritization is not None: - self.worker_side_priorizatiion = worker_side_prioritization if replay_buffer_config is not None: self.replay_buffer_config = replay_buffer_config + return self + # Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead! class _deprecated_default_config(dict): diff --git a/rllib/algorithms/qmix/qmix.py b/rllib/algorithms/qmix/qmix.py index 24c48cc62..f46c2083b 100644 --- a/rllib/algorithms/qmix/qmix.py +++ b/rllib/algorithms/qmix/qmix.py @@ -71,7 +71,6 @@ class QMixConfig(SimpleQConfig): self.optim_alpha = 0.99 self.optim_eps = 0.00001 self.grad_norm_clipping = 10 - self.worker_side_prioritization = False # Override some of TrainerConfig's default values with QMix-specific values. # .training() @@ -136,6 +135,8 @@ class QMixConfig(SimpleQConfig): # __sphinx_doc_end__ # fmt: on + self.worker_side_prioritization = DEPRECATED_VALUE + @override(SimpleQConfig) def training( self, @@ -148,7 +149,6 @@ class QMixConfig(SimpleQConfig): optim_alpha: Optional[float] = None, optim_eps: Optional[float] = None, grad_norm_clipping: Optional[float] = None, - worker_side_prioritization: Optional[bool] = None, **kwargs, ) -> "QMixConfig": """Sets the training related configuration. @@ -164,8 +164,6 @@ class QMixConfig(SimpleQConfig): optim_eps: RMSProp epsilon. grad_norm_clipping: If not None, clip gradients during optimization at this value. - worker_side_prioritization: Whether to compute priorities for the replay - buffer on worker side. Returns: This updated TrainerConfig object. 
@@ -189,8 +187,6 @@ class QMixConfig(SimpleQConfig): self.optim_eps = optim_eps if grad_norm_clipping is not None: self.grad_norm_clipping = grad_norm_clipping - if worker_side_prioritization is not None: - self.worker_side_prioritization = worker_side_prioritization return self diff --git a/rllib/algorithms/sac/__init__.py b/rllib/algorithms/sac/__init__.py index 2004cef5b..54228a312 100644 --- a/rllib/algorithms/sac/__init__.py +++ b/rllib/algorithms/sac/__init__.py @@ -1,4 +1,4 @@ -from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG +from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG, SACConfig from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy @@ -6,14 +6,16 @@ from ray.rllib.algorithms.sac.rnnsac import ( RNNSACTrainer, DEFAULT_CONFIG as RNNSAC_DEFAULT_CONFIG, ) -from ray.rllib.algorithms.sac.rnnsac import RNNSACTorchPolicy +from ray.rllib.algorithms.sac.rnnsac import RNNSACTorchPolicy, RNNSACConfig __all__ = [ "DEFAULT_CONFIG", "SACTFPolicy", "SACTorchPolicy", "SACTrainer", + "SACConfig", "RNNSAC_DEFAULT_CONFIG", "RNNSACTorchPolicy", "RNNSACTrainer", + "RNNSACConfig", ] diff --git a/rllib/algorithms/sac/rnnsac.py b/rllib/algorithms/sac/rnnsac.py index d0abd3bdf..5ee181ce0 100644 --- a/rllib/algorithms/sac/rnnsac.py +++ b/rllib/algorithms/sac/rnnsac.py @@ -1,50 +1,78 @@ -from typing import Type +from typing import Type, Optional -from ray.rllib.algorithms.sac import SACTrainer, DEFAULT_CONFIG as SAC_DEFAULT_CONFIG +from ray.rllib.algorithms.sac import ( + SACTrainer, + SACConfig, +) from ray.rllib.algorithms.sac.rnnsac_torch_policy import RNNSACTorchPolicy from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import TrainerConfigDict -from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated -DEFAULT_CONFIG = SACTrainer.merge_trainer_configs( - SAC_DEFAULT_CONFIG, - { - # Batch mode (see common config) - "batch_mode": "complete_episodes", - # If True, assume a zero-initialized state input (no matter where in - # the episode the sequence is located). - # If False, store the initial states along with each SampleBatch, use - # it (as initial state when running through the network for training), - # and update that initial state during training (from the internal - # state outputs of the immediately preceding sequence). - "zero_init_states": True, - "replay_buffer_config": { - # If > 0, use the `burn_in` first steps of each replay-sampled sequence - # (starting either from all 0.0-values if `zero_init_state=True` or - # from the already stored values) to calculate an even more accurate - # initial states for the actual sequence (starting after this burn-in - # window). In the burn-in case, the actual length of the sequence - # used for loss calculation is `n - burn_in` time steps - # (n=LSTM’s/attention net’s max_seq_len). - "replay_burn_in": 0, - # Set automatically: The number of contiguous environment steps to - # replay at once. Will be calculated via - # model->max_seq_len + burn_in. - # Do not set this to any valid value! - "replay_sequence_length": -1, - }, - "burn_in": DEPRECATED_VALUE, - }, - _allow_unknown_configs=True, -) + +class RNNSACConfig(SACConfig): + """Defines a configuration class from which an RNNSACTrainer can be built. + + Example: + >>> config = RNNSACConfig().training(gamma=0.9, lr=0.01)\ + ... 
.resources(num_gpus=0)\ + ... .rollouts(num_rollout_workers=4) + >>> print(config.to_dict()) + >>> # Build a Trainer object from the config and run 1 training iteration. + >>> trainer = config.build(env="CartPole-v1") + >>> trainer.train() + """ + + def __init__(self, trainer_class=None): + super().__init__(trainer_class=trainer_class or RNNSACTrainer) + # fmt: off + # __sphinx_doc_begin__ + self.burn_in = DEPRECATED_VALUE + self.batch_mode = "complete_episodes" + self.zero_init_states = True + self.replay_buffer_config["replay_burn_in"] = 0 + # Set automatically: The number of contiguous environment steps to + # replay at once. Will be calculated via + # model->max_seq_len + burn_in. + # Do not set this to any valid value! + self.replay_buffer_config["replay_sequence_length"] = -1 + + # fmt: on + # __sphinx_doc_end__ + + @override(SACConfig) + def training( + self, + *, + zero_init_states: Optional[bool] = None, + **kwargs, + ) -> "RNNSACConfig": + """Sets the training related configuration. + + Args: + zero_init_states: If True, assume a zero-initialized state input (no matter + where in the episode the sequence is located). + If False, store the initial states along with each SampleBatch, use + it (as initial state when running through the network for training), + and update that initial state during training (from the internal + state outputs of the immediately preceding sequence). + + Returns: + This updated TrainerConfig object. + """ + super().training(**kwargs) + if zero_init_states is not None: + self.zero_init_states = zero_init_states + + return self class RNNSACTrainer(SACTrainer): @classmethod @override(SACTrainer) def get_default_config(cls) -> TrainerConfigDict: - return DEFAULT_CONFIG + return RNNSACConfig().to_dict() @override(SACTrainer) def validate_config(self, config: TrainerConfigDict) -> None: @@ -81,3 +109,19 @@ class RNNSACTrainer(SACTrainer): @override(SACTrainer) def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: return RNNSACTorchPolicy + + +class _deprecated_default_config(dict): + def __init__(self): + super().__init__(RNNSACConfig().to_dict()) + + @Deprecated( + old="ray.rllib.algorithms.sac.rnnsac.DEFAULT_CONFIG", + new="ray.rllib.algorithms.sac.rnnsac.RNNSACConfig(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + +DEFAULT_CONFIG = _deprecated_default_config() diff --git a/rllib/algorithms/sac/rnnsac_torch_policy.py b/rllib/algorithms/sac/rnnsac_torch_policy.py index 220f22f78..ed19aee36 100644 --- a/rllib/algorithms/sac/rnnsac_torch_policy.py +++ b/rllib/algorithms/sac/rnnsac_torch_policy.py @@ -45,12 +45,12 @@ def build_rnnsac_model( num_outputs = int(np.product(obs_space.shape)) # Force-ignore any additionally provided hidden layer sizes. - # Everything should be configured using SAC's "Q_model" and "policy_model" - # settings. + # Everything should be configured using SAC's `q_model_config` and + # `policy_model_config` config settings. 
policy_model_config = MODEL_DEFAULTS.copy() - policy_model_config.update(config["policy_model"]) + policy_model_config.update(config["policy_model_config"]) q_model_config = MODEL_DEFAULTS.copy() - q_model_config.update(config["Q_model"]) + q_model_config.update(config["q_model_config"]) default_model_cls = RNNSACTorchModel diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py index 3f195bbd7..ca4696e82 100644 --- a/rllib/algorithms/sac/sac.py +++ b/rllib/algorithms/sac/sac.py @@ -1,12 +1,16 @@ import logging -from typing import Type +from typing import Type, Dict, Any, Optional, Union -from ray.rllib.agents.trainer import with_common_config from ray.rllib.algorithms.dqn.dqn import DQNTrainer from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.agents.trainer_config import TrainerConfig +from ray.rllib.utils.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, + Deprecated, +) from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.typing import TrainerConfigDict @@ -15,160 +19,255 @@ tfp = try_import_tfp() logger = logging.getLogger(__name__) -# fmt: off -# __sphinx_doc_begin__ -# Adds the following updates to the (base) `Trainer` config in -# rllib/agents/trainer.py (`COMMON_CONFIG` dict). -DEFAULT_CONFIG = with_common_config({ - # === Model === - # Use two Q-networks (instead of one) for action-value estimation. - # Note: Each Q-network will have its own target network. - "twin_q": True, - # Use a e.g. conv2D state preprocessing network before concatenating the - # resulting (feature) vector with the action input for the input to - # the Q-networks. - "use_state_preprocessor": DEPRECATED_VALUE, - # Model options for the Q network(s). These will override MODEL_DEFAULTS. - # The `Q_model` dict is treated just as the top-level `model` dict in - # setting up the Q-network(s) (2 if twin_q=True). - # That means, you can do for different observation spaces: - # obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet - # obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action - # -> post_fcnet - # obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action) - # -> vision-net -> concat w/ Box(1D) and action -> post_fcnet - # You can also have SAC use your custom_model as Q-model(s), by simply - # specifying the `custom_model` sub-key in below dict (just like you would - # do in the top-level `model` dict. - "Q_model": { - "fcnet_hiddens": [256, 256], - "fcnet_activation": "relu", - "post_fcnet_hiddens": [], - "post_fcnet_activation": None, - "custom_model": None, # Use this to define custom Q-model(s). - "custom_model_config": {}, - }, - # Model options for the policy function (see `Q_model` above for details). - # The difference to `Q_model` above is that no action concat'ing is - # performed before the post_fcnet stack. - "policy_model": { - "fcnet_hiddens": [256, 256], - "fcnet_activation": "relu", - "post_fcnet_hiddens": [], - "post_fcnet_activation": None, - "custom_model": None, # Use this to define a custom policy model. - "custom_model_config": {}, - }, - # Actions are already normalized, no need to clip them further. - "clip_actions": False, +class SACConfig(TrainerConfig): + """Defines a configuration class from which an SACTrainer can be built. 
- # === Learning === - # Update the target by \tau * policy + (1-\tau) * target_policy. - "tau": 5e-3, - # Initial value to use for the entropy weight alpha. - "initial_alpha": 1.0, - # Target entropy lower bound. If "auto", will be set to -|A| (e.g. -2.0 for - # Discrete(2), -3.0 for Box(shape=(3,))). - # This is the inverse of reward scale, and will be optimized automatically. - "target_entropy": "auto", - # N-step target updates. If >1, sars' tuples in trajectories will be - # postprocessed to become sa[discounted sum of R][s t+n] tuples. - "n_step": 1, - # Minimum env sampling timesteps to accumulate within a single `train()` call. This - # value does not affect learning, only the number of times `Trainer.step_attempt()` - # is called by `Trauber.train()`. If - after one `step_attempt()`, the env sampling - # timestep count has not been reached, will perform n more `step_attempt()` calls - # until the minimum timesteps have been executed. Set to 0 for no minimum timesteps. - "min_sample_timesteps_per_reporting": 100, + Example: + >>> config = SACConfig().training(gamma=0.9, lr=0.01)\ + ... .resources(num_gpus=0)\ + ... .rollouts(num_rollout_workers=4) + >>> print(config.to_dict()) + >>> # Build a Trainer object from the config and run 1 training iteration. + >>> trainer = config.build(env="CartPole-v1") + >>> trainer.train() + """ - # === Replay buffer === - "replay_buffer_config": { - "type": "MultiAgentReplayBuffer", - # Specify prioritized replay by supplying a buffer type that supports - # prioritization, for example: MultiAgentPrioritizedReplayBuffer. - "prioritized_replay": DEPRECATED_VALUE, - "capacity": int(1e6), - # How many steps of the model to sample before learning starts. - "learning_starts": 1500, - # The number of continuous environment steps to replay at once. This may - # be set to greater than 1 to support recurrent models. - "replay_sequence_length": 1, - "prioritized_replay_alpha": 0.6, - # Beta parameter for sampling from prioritized replay buffer. - "prioritized_replay_beta": 0.4, - # Epsilon to add to the TD errors when updating priorities. - "prioritized_replay_eps": 1e-6, - # Whether to compute priorities on workers. - "worker_side_prioritization": False, - }, - # Set this to True, if you want the contents of your buffer(s) to be - # stored in any saved checkpoints as well. - # Warnings will be created if: - # - This is True AND restoring from a checkpoint that contains no buffer - # data. - # - This is False AND restoring from a checkpoint that does contain - # buffer data. - "store_buffer_in_checkpoints": False, - # Whether to LZ4 compress observations - "compress_observations": False, + def __init__(self, trainer_class=None): + super().__init__(trainer_class=trainer_class or SACTrainer) + # fmt: off + # __sphinx_doc_begin__ + # SAC-specific config settings. + self.twin_q = True + self.q_model_config = { + "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "custom_model": None, # Use this to define custom Q-model(s). + "custom_model_config": {}, + } + self.policy_model_config = { + "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "custom_model": None, # Use this to define a custom policy model. 
+ "custom_model_config": {}, + } + self.clip_actions = False + self.tau = 5e-3 + self.initial_alpha = 1.0 + self.target_entropy = "auto" + self.n_step = 1 + self.replay_buffer_config = { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": int(1e6), + # How many steps of the model to sample before learning starts. + "learning_starts": 1500, + # If True prioritized replay buffer will be used. + "prioritized_replay": False, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + # Whether to compute priorities already on the remote worker side. + "worker_side_prioritization": False, + } + self.store_buffer_in_checkpoints = False + self.training_intensity = None + self.optimization = { + "actor_learning_rate": 3e-4, + "critic_learning_rate": 3e-4, + "entropy_learning_rate": 3e-4, + } + self.grad_clip = None + self.target_network_update_freq = 0 - # The intensity with which to update the model (vs collecting samples from - # the env). If None, uses the "natural" value of: - # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x - # `num_envs_per_worker`). - # If provided, will make sure that the ratio between ts inserted into and - # sampled from the buffer matches the given value. - # Example: - # training_intensity=1000.0 - # train_batch_size=250 rollout_fragment_length=1 - # num_workers=1 (or 0) num_envs_per_worker=1 - # -> natural value = 250 / 1 = 250.0 - # -> will make sure that replay+train op will be executed 4x as - # often as rollout+insert op (4 * 250 = 1000). - # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details. - "training_intensity": None, + # .rollout() + self.rollout_fragment_length = 1 + self.compress_observations = False - # === Optimization === - "optimization": { - "actor_learning_rate": 3e-4, - "critic_learning_rate": 3e-4, - "entropy_learning_rate": 3e-4, - }, - # If not None, clip gradients during optimization at this value. - "grad_clip": None, - # Update the replay buffer with this many samples at once. Note that this - # setting applies per-worker if num_workers > 1. - "rollout_fragment_length": 1, - # Size of a batched sampled from replay buffer for training. - "train_batch_size": 256, - # Update the target network every `target_network_update_freq` sample steps. - "target_network_update_freq": 0, + # .training() + self.train_batch_size = 256 - # === Parallelism === - # Whether to use a GPU for local optimization. - "num_gpus": 0, - # Number of workers for collecting samples with. This only makes sense - # to increase if your environment is particularly slow to sample, or if - # you"re using the Async or Ape-X optimizers. - "num_workers": 0, - # Whether to allocate GPUs for workers (if > 0). - "num_gpus_per_worker": 0, - # Whether to allocate CPUs for workers (if > 0). - "num_cpus_per_worker": 1, - # Prevent reporting frequency from going lower than this time span. - "min_time_s_per_reporting": 1, + # .reporting() + self.min_time_s_per_reporting = 1 + self.min_sample_timesteps_per_reporting = 100 + # __sphinx_doc_end__ + # fmt: on - # Whether the loss should be calculated deterministically (w/o the - # stochastic action sampling step). True only useful for cont. actions and - # for debugging! - "_deterministic_loss": False, - # Use a Beta-distribution instead of a SquashedGaussian for bounded, - # continuous action spaces (not recommended, for debugging only). 
- "_use_beta_distribution": False, -}) -# __sphinx_doc_end__ -# fmt: on + self._deterministic_loss = False + self._use_beta_distribution = False + + self.use_state_preprocessor = DEPRECATED_VALUE + self.worker_side_prioritization = DEPRECATED_VALUE + + @override(TrainerConfig) + def training( + self, + *, + twin_q: Optional[bool] = None, + q_model_config: Optional[Dict[str, Any]] = None, + policy_model_config: Optional[Dict[str, Any]] = None, + tau: Optional[float] = None, + initial_alpha: Optional[float] = None, + target_entropy: Optional[Union[str, float]] = None, + n_step: Optional[int] = None, + store_buffer_in_checkpoints: Optional[bool] = None, + replay_buffer_config: Optional[Dict[str, Any]] = None, + training_intensity: Optional[float] = None, + clip_actions: Optional[bool] = None, + grad_clip: Optional[float] = None, + optimization_config: Optional[Dict[str, Any]] = None, + target_network_update_freq: Optional[int] = None, + _deterministic_loss: Optional[bool] = None, + _use_beta_distribution: Optional[bool] = None, + **kwargs, + ) -> "SACConfig": + """Sets the training related configuration. + + Args: + twin_q: Use two Q-networks (instead of one) for action-value estimation. + Note: Each Q-network will have its own target network. + q_model_config: Model configs for the Q network(s). These will override + MODEL_DEFAULTS. This is treated just as the top-level `model` dict in + setting up the Q-network(s) (2 if twin_q=True). + That means, you can do for different observation spaces: + obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet + obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action + -> post_fcnet + obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action) + -> vision-net -> concat w/ Box(1D) and action -> post_fcnet + You can also have SAC use your custom_model as Q-model(s), by simply + specifying the `custom_model` sub-key in below dict (just like you would + do in the top-level `model` dict. + policy_model_config: Model options for the policy function (see + `q_model_config` above for details). The difference to `q_model_config` + above is that no action concat'ing is performed before the post_fcnet + stack. + tau: Update the target by \tau * policy + (1-\tau) * target_policy. + initial_alpha: Initial value to use for the entropy weight alpha. + target_entropy: Target entropy lower bound. If "auto", will be set + to -|A| (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))). + This is the inverse of reward scale, and will be optimized + automatically. + n_step: N-step target updates. If >1, sars' tuples in trajectories will be + postprocessed to become sa[discounted sum of R][s t+n] tuples. + store_buffer_in_checkpoints: Set this to True, if you want the contents of + your buffer(s) to be stored in any saved checkpoints as well. + Warnings will be created if: + - This is True AND restoring from a checkpoint that contains no buffer + data. + - This is False AND restoring from a checkpoint that does contain + buffer data. + replay_buffer_config: Replay buffer config. 
+ Examples: + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentReplayBuffer", + "learning_starts": 1000, + "capacity": 50000, + "replay_batch_size": 32, + "replay_sequence_length": 1, + } + - OR - + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": 50000, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + "replay_sequence_length": 1, + } + - Where - + prioritized_replay_alpha: Alpha parameter controls the degree of + prioritization in the buffer. In other words, when a buffer sample has + a higher temporal-difference error, with how much higher probability + should it be drawn to update the parametrized Q-network. 0.0 + corresponds to uniform probability. Setting this much above 1.0 may quickly + result in the sampling distribution becoming heavily “pointy” with + low entropy. + prioritized_replay_beta: Beta parameter controls the degree of + importance sampling which suppresses the influence of gradient updates + from samples that have higher probability of being sampled via alpha + parameter and the temporal-difference error. + prioritized_replay_eps: Epsilon parameter sets the baseline probability + for sampling so that when the temporal-difference error of a sample is + zero, there is still a chance of drawing the sample. + training_intensity: The intensity with which to update the model (vs + collecting samples from the env). + If None, uses the "natural" value of: + `train_batch_size` / (`rollout_fragment_length` x `num_workers` x + `num_envs_per_worker`). + If not None, will make sure that the ratio between timesteps inserted + into and sampled from the buffer matches the given value. + Example: + training_intensity=1000.0 + train_batch_size=250 + rollout_fragment_length=1 + num_workers=1 (or 0) + num_envs_per_worker=1 + -> natural value = 250 / 1 = 250.0 + -> will make sure that replay+train op will be executed 4x as often as + rollout+insert op (4 * 250 = 1000). + See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details. + clip_actions: Whether to clip actions. If actions are already normalized, + this should be set to False. + grad_clip: If not None, clip gradients during optimization at this value. + optimization_config: Config dict for optimization. Set the supported keys + `actor_learning_rate`, `critic_learning_rate`, and + `entropy_learning_rate` in here. + target_network_update_freq: Update the target network every + `target_network_update_freq` steps. + _deterministic_loss: Whether the loss should be calculated deterministically + (w/o the stochastic action sampling step). True only useful for + continuous actions and for debugging. + _use_beta_distribution: Use a Beta-distribution instead of a + `SquashedGaussian` for bounded, continuous action spaces (not + recommended; for debugging only). + + Returns: + This updated TrainerConfig object. + """ + # Pass kwargs onto super's `training()` method.
+ super().training(**kwargs) + + if twin_q is not None: + self.twin_q = twin_q + if q_model_config is not None: + self.q_model_config = q_model_config + if policy_model_config is not None: + self.policy_model_config = policy_model_config + if tau is not None: + self.tau = tau + if initial_alpha is not None: + self.initial_alpha = initial_alpha + if target_entropy is not None: + self.target_entropy = target_entropy + if n_step is not None: + self.n_step = n_step + if store_buffer_in_checkpoints is not None: + self.store_buffer_in_checkpoints = store_buffer_in_checkpoints + if replay_buffer_config is not None: + self.replay_buffer_config = replay_buffer_config + if training_intensity is not None: + self.training_intensity = training_intensity + if clip_actions is not None: + self.clip_actions = clip_actions + if grad_clip is not None: + self.grad_clip = grad_clip + if optimization_config is not None: + self.optimization = optimization_config + if target_network_update_freq is not None: + self.target_network_update_freq = target_network_update_freq + if _deterministic_loss is not None: + self._deterministic_loss = _deterministic_loss + if _use_beta_distribution is not None: + self._use_beta_distribution = _use_beta_distribution + + return self class SACTrainer(DQNTrainer): @@ -183,13 +282,13 @@ class SACTrainer(DQNTrainer): """ def __init__(self, *args, **kwargs): - self._allow_unknown_subkeys += ["policy_model", "Q_model"] + self._allow_unknown_subkeys += ["policy_model_config", "q_model_config"] super().__init__(*args, **kwargs) @classmethod @override(DQNTrainer) def get_default_config(cls) -> TrainerConfigDict: - return DEFAULT_CONFIG + return SACConfig().to_dict() @override(DQNTrainer) def validate_config(self, config: TrainerConfigDict) -> None: @@ -200,6 +299,22 @@ class SACTrainer(DQNTrainer): deprecation_warning(old="config['use_state_preprocessor']", error=False) config["use_state_preprocessor"] = DEPRECATED_VALUE + if config.get("policy_model", DEPRECATED_VALUE) != DEPRECATED_VALUE: + deprecation_warning( + old="config['policy_model']", + new="config['policy_model_config']", + error=False, + ) + config["policy_model_config"] = config["policy_model"] + + if config.get("Q_model", DEPRECATED_VALUE) != DEPRECATED_VALUE: + deprecation_warning( + old="config['Q_model']", + new="config['q_model_config']", + error=False, + ) + config["q_model_config"] = config["Q_model"] + + if config["grad_clip"] is not None and config["grad_clip"] <= 0.0: raise ValueError("`grad_clip` value must be > 0.0!") @@ -220,3 +335,20 @@ class SACTrainer(DQNTrainer): return SACTorchPolicy else: return SACTFPolicy + + +# Deprecated: Use ray.rllib.algorithms.sac.SACConfig instead! +class _deprecated_default_config(dict): + def __init__(self): + super().__init__(SACConfig().to_dict()) + + @Deprecated( + old="ray.rllib.algorithms.sac.sac.DEFAULT_CONFIG", + new="ray.rllib.algorithms.sac.sac.SACConfig(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + +DEFAULT_CONFIG = _deprecated_default_config() diff --git a/rllib/algorithms/sac/sac_tf_model.py b/rllib/algorithms/sac/sac_tf_model.py index dc9201d4d..73e25cc4f 100644 --- a/rllib/algorithms/sac/sac_tf_model.py +++ b/rllib/algorithms/sac/sac_tf_model.py @@ -19,9 +19,9 @@ class SACTFModel(TFModelV2): To customize, do one of the following: - sub-class SACTFModel and override one or more of its methods.
- - Use SAC's `Q_model` and `policy_model` keys to tweak the default model + - Use SAC's `q_model_config` and `policy_model_config` keys to tweak the default model behaviors (e.g. fcnet_hiddens, conv_filters, etc..). - - Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys + - Use SAC's `q_model_config->custom_model` and `policy_model_config->custom_model` keys to specify your own custom Q-model(s) and policy-models, which will be created within this SACTFModel (see `build_policy_model` and `build_q_model`. @@ -160,7 +160,7 @@ class SACTFModel(TFModelV2): Override this method in a sub-class of SACTFModel to implement your own Q-nets. Alternatively, simply set `custom_model` within the - top level SAC `Q_model` config key to make this default implementation + top level SAC `q_model_config` config key to make this default implementation of `build_q_model` use your custom Q-nets. Returns: diff --git a/rllib/algorithms/sac/sac_tf_policy.py b/rllib/algorithms/sac/sac_tf_policy.py index 5b409f542..aadae5f58 100644 --- a/rllib/algorithms/sac/sac_tf_policy.py +++ b/rllib/algorithms/sac/sac_tf_policy.py @@ -72,12 +72,12 @@ def build_sac_model( `policy.target_model`. """ # Force-ignore any additionally provided hidden layer sizes. - # Everything should be configured using SAC's "Q_model" and "policy_model" - # settings. + # Everything should be configured using SAC's `q_model_config` and + # `policy_model_config` config settings. policy_model_config = copy.deepcopy(MODEL_DEFAULTS) - policy_model_config.update(config["policy_model"]) + policy_model_config.update(config["policy_model_config"]) q_model_config = copy.deepcopy(MODEL_DEFAULTS) - q_model_config.update(config["Q_model"]) + q_model_config.update(config["q_model_config"]) default_model_cls = SACTorchModel if config["framework"] == "torch" else SACTFModel diff --git a/rllib/algorithms/sac/sac_torch_model.py b/rllib/algorithms/sac/sac_torch_model.py index 68e734baa..b97f95272 100644 --- a/rllib/algorithms/sac/sac_torch_model.py +++ b/rllib/algorithms/sac/sac_torch_model.py @@ -19,9 +19,9 @@ class SACTorchModel(TorchModelV2, nn.Module): To customize, do one of the following: - sub-class SACTorchModel and override one or more of its methods. - - Use SAC's `Q_model` and `policy_model` keys to tweak the default model + - Use SAC's `q_model_config` and `policy_model_config` keys to tweak the default model behaviors (e.g. fcnet_hiddens, conv_filters, etc..). - - Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys + - Use SAC's `q_model_config->custom_model` and `policy_model_config->custom_model` keys to specify your own custom Q-model(s) and policy-models, which will be created within this SACTFModel (see `build_policy_model` and `build_q_model`. @@ -168,7 +168,7 @@ class SACTorchModel(TorchModelV2, nn.Module): Override this method in a sub-class of SACTFModel to implement your own Q-nets. Alternatively, simply set `custom_model` within the - top level SAC `Q_model` config key to make this default implementation + top level SAC `q_model_config` config key to make this default implementation of `build_q_model` use your custom Q-nets.
Returns: diff --git a/rllib/algorithms/sac/tests/test_rnnsac.py b/rllib/algorithms/sac/tests/test_rnnsac.py index e5d8a53ee..3f93d102f 100644 --- a/rllib/algorithms/sac/tests/test_rnnsac.py +++ b/rllib/algorithms/sac/tests/test_rnnsac.py @@ -1,7 +1,7 @@ import unittest import ray -import ray.rllib.algorithms.sac as sac +from ray.rllib.algorithms import sac from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_compute_single_action, framework_iterator @@ -20,42 +20,39 @@ class TestRNNSAC(unittest.TestCase): def test_rnnsac_compilation(self): """Test whether a R2D2Trainer can be built on all frameworks.""" - config = sac.RNNSAC_DEFAULT_CONFIG.copy() - config["num_workers"] = 0 # Run locally. - - # Wrap with an LSTM and use a very simple base-model. - config["model"] = { - "max_seq_len": 20, - } - config["policy_model"] = { - "use_lstm": True, - "lstm_cell_size": 64, - "fcnet_hiddens": [10], - "lstm_use_prev_action": True, - "lstm_use_prev_reward": True, - } - config["Q_model"] = { - "use_lstm": True, - "lstm_cell_size": 64, - "fcnet_hiddens": [10], - "lstm_use_prev_action": True, - "lstm_use_prev_reward": True, - } - - # Test with MultiAgentPrioritizedReplayBuffer - config["replay_buffer_config"] = { - "type": "MultiAgentPrioritizedReplayBuffer", - "replay_burn_in": 20, - "zero_init_states": True, - } - - config["lr"] = 5e-4 - + config = ( + sac.RNNSACConfig() + .rollouts(num_rollout_workers=0) + .training( + # Wrap with an LSTM and use a very simple base-model. + model={"max_seq_len": 20}, + policy_model_config={ + "use_lstm": True, + "lstm_cell_size": 64, + "fcnet_hiddens": [10], + "lstm_use_prev_action": True, + "lstm_use_prev_reward": True, + }, + q_model_config={ + "use_lstm": True, + "lstm_cell_size": 64, + "fcnet_hiddens": [10], + "lstm_use_prev_action": True, + "lstm_use_prev_reward": True, + }, + replay_buffer_config={ + "type": "MultiAgentPrioritizedReplayBuffer", + "replay_burn_in": 20, + "zero_init_states": True, + }, + lr=5e-4, + ) + ) num_iterations = 1 # Test building an RNNSAC agent in all frameworks. for _ in framework_iterator(config, frameworks="torch"): - trainer = sac.RNNSACTrainer(config=config, env="CartPole-v0") + trainer = config.build(env="CartPole-v0") for i in range(num_iterations): results = trainer.train() print(results) diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py index 302fb88b7..54221b44c 100644 --- a/rllib/algorithms/sac/tests/test_sac.py +++ b/rllib/algorithms/sac/tests/test_sac.py @@ -6,7 +6,7 @@ import re import unittest import ray -import ray.rllib.algorithms.sac as sac +from ray.rllib.algorithms import sac from ray.rllib.algorithms.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss from ray.rllib.algorithms.sac.sac_torch_policy import actor_critic_loss as loss_torch from ray.rllib.examples.env.random_env import RandomEnv @@ -74,20 +74,17 @@ class TestSAC(unittest.TestCase): def test_sac_compilation(self): """Tests whether an SACTrainer can be built with all frameworks.""" - config = sac.DEFAULT_CONFIG.copy() - config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy() - config["num_workers"] = 0 # Run locally. 
- config["n_step"] = 3 - config["twin_q"] = True - config["replay_buffer_config"]["learning_starts"] = 0 - config["rollout_fragment_length"] = 10 - config["train_batch_size"] = 10 - # If we use default buffer size (1e6), the buffer will take up - # 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021) - # available system memory (8.34816 GB). - config["replay_buffer_config"]["capacity"] = 40000 - # Test with saved replay buffer. - config["store_buffer_in_checkpoints"] = True + config = ( + sac.SACConfig() + .training( + n_step=3, + twin_q=True, + replay_buffer_config={"learning_starts": 0, "capacity": 40000}, + store_buffer_in_checkpoints=True, + train_batch_size=10, + ) + .rollouts(num_rollout_workers=0, rollout_fragment_length=10) + ) num_iterations = 1 ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel) @@ -134,12 +131,12 @@ class TestSAC(unittest.TestCase): print("Env={}".format(env)) # Test making the Q-model a custom one for CartPole, otherwise, # use the default model. - config["Q_model"]["custom_model"] = ( + config.q_model_config["custom_model"] = ( "batch_norm{}".format("_torch" if fw == "torch" else "") if env == "CartPole-v0" else None ) - trainer = sac.SACTrainer(config=config, env=env) + trainer = config.build(env=env) for i in range(num_iterations): results = trainer.train() check_train_results(results) @@ -167,22 +164,25 @@ class TestSAC(unittest.TestCase): def test_sac_loss_function(self): """Tests SAC loss function results across all frameworks.""" - config = sac.DEFAULT_CONFIG.copy() - # Run locally. - config["seed"] = 42 - config["num_workers"] = 0 - config["replay_buffer_config"]["learning_starts"] = 0 - config["twin_q"] = False - config["gamma"] = 0.99 - # Switch on deterministic loss so we can compare the loss values. - config["_deterministic_loss"] = True - # Use very simple nets. - config["Q_model"]["fcnet_hiddens"] = [10] - config["policy_model"]["fcnet_hiddens"] = [10] - # Make sure, timing differences do not affect trainer.train(). - config["min_time_s_per_reporting"] = 0 - # Test SAC with Simplex action space. - config["env_config"] = {"simplex_actions": True} + config = ( + sac.SACConfig() + .training( + twin_q=False, + gamma=0.99, + _deterministic_loss=True, + q_model_config={"fcnet_hiddens": [10]}, + policy_model_config={"fcnet_hiddens": [10]}, + replay_buffer_config={"learning_starts": 0}, + ) + .rollouts(num_rollout_workers=0) + .reporting( + min_time_s_per_reporting=0, + ) + .environment( + env_config={"simplex_actions": True}, + ) + .debugging(seed=42) + ) map_ = { # Action net. @@ -247,7 +247,7 @@ class TestSAC(unittest.TestCase): config, frameworks=("tf", "torch"), session=True ): # Generate Trainer and get its default Policy object. - trainer = sac.SACTrainer(config=config, env=env) + trainer = config.build(env=env) policy = trainer.get_policy() p_sess = None if sess: @@ -287,7 +287,7 @@ class TestSAC(unittest.TestCase): sorted(weights_dict.keys()), log_alpha, fw, - gamma=config["gamma"], + gamma=config.gamma, sess=sess, ) @@ -520,19 +520,22 @@ class TestSAC(unittest.TestCase): return dict_samples[self.steps], 1, self.steps >= 5, {} tune.register_env("nested", lambda _: NestedDictEnv()) - - config = sac.DEFAULT_CONFIG.copy() - config["num_workers"] = 0 # Run locally. - config["replay_buffer_config"]["learning_starts"] = 0 - config["rollout_fragment_length"] = 5 - config["train_batch_size"] = 5 - config["replay_buffer_config"]["capacity"] = 10 - # Disable preprocessors. 
- config["_disable_preprocessor_api"] = True + config = ( + sac.SACConfig() + .training( + replay_buffer_config={"learning_starts": 0, "capacity": 10}, + train_batch_size=5, + ) + .rollouts( + num_rollout_workers=0, + rollout_fragment_length=5, + ) + .experimental(_disable_preprocessor_api=True) + ) num_iterations = 1 for _ in framework_iterator(config, with_eager_tracing=True): - trainer = sac.SACTrainer(env="nested", config=config) + trainer = config.build(env="nested") for _ in range(num_iterations): results = trainer.train() check_train_results(results) diff --git a/rllib/algorithms/slateq/slateq.py b/rllib/algorithms/slateq/slateq.py index e601f70f4..98ac6b99f 100644 --- a/rllib/algorithms/slateq/slateq.py +++ b/rllib/algorithms/slateq/slateq.py @@ -139,7 +139,6 @@ class SlateQConfig(TrainerConfig): rmsprop_epsilon: Optional[float] = None, grad_clip: Optional[float] = None, n_step: Optional[int] = None, - worker_side_prioritization: Optional[bool] = None, **kwargs, ) -> "SlateQConfig": """Sets the training related configuration. @@ -172,8 +171,6 @@ class SlateQConfig(TrainerConfig): rmsprop_epsilon: RMSProp epsilon hyperparameter. grad_clip: If not None, clip gradients during optimization at this value. n_step: N-step parameter for Q-learning. - worker_side_prioritization: Whether to compute priorities for the replay - buffer on the workers. Returns: This updated TrainerConfig object. @@ -205,8 +202,6 @@ class SlateQConfig(TrainerConfig): self.grad_clip = grad_clip if n_step is not None: self.n_step = n_step - if worker_side_prioritization is not None: - self.worker_side_prioritization = worker_side_prioritization return self diff --git a/rllib/examples/offline_rl.py b/rllib/examples/offline_rl.py index 96ce8841f..0052dd06b 100644 --- a/rllib/examples/offline_rl.py +++ b/rllib/examples/offline_rl.py @@ -29,7 +29,7 @@ if __name__ == "__main__": # See rllib/tuned_examples/cql/pendulum-cql.yaml for comparison. - config = cql.CQL_DEFAULT_CONFIG.copy() + config = cql.DEFAULT_CONFIG.copy() config["num_workers"] = 0 # Run locally. config["horizon"] = 200 config["soft_horizon"] = True @@ -45,11 +45,11 @@ if __name__ == "__main__": config["replay_buffer_config"]["capacity"] = int(1e6) config["tau"] = 0.005 config["target_entropy"] = "auto" - config["Q_model"] = { + config["q_model_config"] = { "fcnet_hiddens": [256, 256], "fcnet_activation": "relu", } - config["policy_model"] = { + config["policy_model_config"] = { "fcnet_hiddens": [256, 256], "fcnet_activation": "relu", } diff --git a/rllib/examples/rnnsac_stateless_cartpole.py b/rllib/examples/rnnsac_stateless_cartpole.py index 0a5fa2972..d79d74e11 100644 --- a/rllib/examples/rnnsac_stateless_cartpole.py +++ b/rllib/examples/rnnsac_stateless_cartpole.py @@ -55,14 +55,14 @@ config = { "model": { "max_seq_len": 20, }, - "policy_model": { + "policy_model_config": { "use_lstm": True, "lstm_cell_size": 64, "fcnet_hiddens": [64, 64], "lstm_use_prev_action": True, "lstm_use_prev_reward": True, }, - "Q_model": { + "q_model_config": { "use_lstm": True, "lstm_cell_size": 64, "fcnet_hiddens": [64, 64], diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py index 5a29e7b78..c5ad1c539 100644 --- a/rllib/policy/tests/test_compute_log_likelihoods.py +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -177,8 +177,8 @@ class TestComputeLogLikelihood(unittest.TestCase): """Tests SAC's (cont. 
actions) compute_log_likelihoods method.""" config = sac.DEFAULT_CONFIG.copy() config["seed"] = 42 - config["policy_model"]["fcnet_hiddens"] = [10] - config["policy_model"]["fcnet_activation"] = "linear" + config["policy_model_config"]["fcnet_hiddens"] = [10] + config["policy_model_config"]["fcnet_activation"] = "linear" prev_a = np.array([0.0]) # SAC cont uses a squashed normal distribution. Implement it's logp @@ -210,8 +210,8 @@ class TestComputeLogLikelihood(unittest.TestCase): """Tests SAC's (discrete actions) compute_log_likelihoods method.""" config = sac.DEFAULT_CONFIG.copy() config["seed"] = 42 - config["policy_model"]["fcnet_hiddens"] = [10] - config["policy_model"]["fcnet_activation"] = "linear" + config["policy_model_config"]["fcnet_hiddens"] = [10] + config["policy_model_config"]["fcnet_activation"] = "linear" prev_a = np.array(0) do_test_log_likelihood(sac.SACTrainer, config, prev_a) diff --git a/rllib/tuned_examples/cql/halfcheetah-bc.yaml b/rllib/tuned_examples/cql/halfcheetah-bc.yaml index 4d9ec0298..3ccc51598 100644 --- a/rllib/tuned_examples/cql/halfcheetah-bc.yaml +++ b/rllib/tuned_examples/cql/halfcheetah-bc.yaml @@ -15,10 +15,10 @@ halfcheetah_bc: framework: torch soft_horizon: False horizon: 1000 - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] tau: 0.005 diff --git a/rllib/tuned_examples/cql/halfcheetah-cql.yaml b/rllib/tuned_examples/cql/halfcheetah-cql.yaml index 741fdbb87..17d85a32f 100644 --- a/rllib/tuned_examples/cql/halfcheetah-cql.yaml +++ b/rllib/tuned_examples/cql/halfcheetah-cql.yaml @@ -17,10 +17,10 @@ halfcheetah_cql: framework: tf soft_horizon: False horizon: 1000 - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] tau: 0.005 diff --git a/rllib/tuned_examples/cql/hopper-bc.yaml b/rllib/tuned_examples/cql/hopper-bc.yaml index 2351ecc5c..31ba978cd 100644 --- a/rllib/tuned_examples/cql/hopper-bc.yaml +++ b/rllib/tuned_examples/cql/hopper-bc.yaml @@ -15,10 +15,10 @@ hopper_bc: framework: torch soft_horizon: False horizon: 1000 - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] tau: 0.005 diff --git a/rllib/tuned_examples/cql/hopper-cql.yaml b/rllib/tuned_examples/cql/hopper-cql.yaml index 6dda3c7f6..dd32efe7d 100644 --- a/rllib/tuned_examples/cql/hopper-cql.yaml +++ b/rllib/tuned_examples/cql/hopper-cql.yaml @@ -15,10 +15,10 @@ hopper_cql: framework: torch soft_horizon: False horizon: 1000 - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256, 256] tau: 0.005 diff --git a/rllib/tuned_examples/ddpg/hopper-pybullet-ddpg.yaml b/rllib/tuned_examples/ddpg/hopper-pybullet-ddpg.yaml index 1a149aed6..d3112dc42 100644 --- a/rllib/tuned_examples/ddpg/hopper-pybullet-ddpg.yaml +++ b/rllib/tuned_examples/ddpg/hopper-pybullet-ddpg.yaml @@ -31,7 +31,7 @@ ddpg-hopperbulletenv-v0: prioritized_replay_alpha: 0.6 prioritized_replay_beta: 0.4 prioritized_replay_eps: 0.000001 - worker_side_prioritization: False + worker_side_prioritization: false learning_starts: 500 clip_rewards: False actor_lr: 0.001 diff --git a/rllib/tuned_examples/ddpg/mountaincarcontinuous-ddpg.yaml 
b/rllib/tuned_examples/ddpg/mountaincarcontinuous-ddpg.yaml index 9d2f08919..dc8388921 100644 --- a/rllib/tuned_examples/ddpg/mountaincarcontinuous-ddpg.yaml +++ b/rllib/tuned_examples/ddpg/mountaincarcontinuous-ddpg.yaml @@ -37,7 +37,7 @@ mountaincarcontinuous-ddpg: prioritized_replay_alpha: 0.6 prioritized_replay_beta: 0.4 prioritized_replay_eps: 0.000001 - worker_side_prioritization: False + worker_side_prioritization: false clip_rewards: False # === Optimization === diff --git a/rllib/tuned_examples/ddpg/pendulum-ddpg.yaml b/rllib/tuned_examples/ddpg/pendulum-ddpg.yaml index 62ae1b0c6..54779266d 100644 --- a/rllib/tuned_examples/ddpg/pendulum-ddpg.yaml +++ b/rllib/tuned_examples/ddpg/pendulum-ddpg.yaml @@ -36,7 +36,7 @@ pendulum-ddpg: replay_buffer_config: type: MultiAgentPrioritizedReplayBuffer capacity: 10000 - worker_side_prioritization: False + worker_side_prioritization: false clip_rewards: False # === Optimization === diff --git a/rllib/tuned_examples/sac/atari-sac.yaml b/rllib/tuned_examples/sac/atari-sac.yaml index d9d984c78..80c5bf108 100644 --- a/rllib/tuned_examples/sac/atari-sac.yaml +++ b/rllib/tuned_examples/sac/atari-sac.yaml @@ -14,10 +14,10 @@ atari-sac-tf-and-torch: framework: grid_search: [tf, torch] gamma: 0.99 - Q_model: + q_model_config: hidden_activation: relu hidden_layer_sizes: [512] - policy_model: + policy_model_config: hidden_activation: relu hidden_layer_sizes: [512] # Do hard syncs. diff --git a/rllib/tuned_examples/sac/halfcheetah-pybullet-sac.yaml b/rllib/tuned_examples/sac/halfcheetah-pybullet-sac.yaml index d415b4ca3..03848c903 100644 --- a/rllib/tuned_examples/sac/halfcheetah-pybullet-sac.yaml +++ b/rllib/tuned_examples/sac/halfcheetah-pybullet-sac.yaml @@ -8,10 +8,10 @@ halfcheetah-pybullet-sac: framework: tf horizon: 1000 soft_horizon: false - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256] tau: 0.005 diff --git a/rllib/tuned_examples/sac/halfcheetah-sac.yaml b/rllib/tuned_examples/sac/halfcheetah-sac.yaml index c3f3704f0..3d2678674 100644 --- a/rllib/tuned_examples/sac/halfcheetah-sac.yaml +++ b/rllib/tuned_examples/sac/halfcheetah-sac.yaml @@ -9,10 +9,10 @@ halfcheetah_sac: framework: tf horizon: 1000 soft_horizon: false - Q_model: + q_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256] - policy_model: + policy_model_config: fcnet_activation: relu fcnet_hiddens: [256, 256] tau: 0.005 diff --git a/rllib/tuned_examples/sac/mspacman-sac.yaml b/rllib/tuned_examples/sac/mspacman-sac.yaml index 96f83bb87..6d0189f89 100644 --- a/rllib/tuned_examples/sac/mspacman-sac.yaml +++ b/rllib/tuned_examples/sac/mspacman-sac.yaml @@ -11,10 +11,10 @@ mspacman-sac-tf: # Works for both torch and tf. framework: tf gamma: 0.99 - Q_model: + q_model_config: fcnet_hiddens: [512] fcnet_activation: relu - policy_model: + policy_model_config: fcnet_hiddens: [512] fcnet_activation: relu # Do hard syncs. 
diff --git a/rllib/tuned_examples/sac/pendulum-sac-fake-gpus.yaml b/rllib/tuned_examples/sac/pendulum-sac-fake-gpus.yaml
index 93fd319c8..cec9e3d8d 100644
--- a/rllib/tuned_examples/sac/pendulum-sac-fake-gpus.yaml
+++ b/rllib/tuned_examples/sac/pendulum-sac-fake-gpus.yaml
@@ -10,10 +10,10 @@ pendulum-sac-fake-gpus:
         framework: tf
         horizon: 200
         soft_horizon: false
-        Q_model:
+        q_model_config:
             fcnet_activation: relu
             fcnet_hiddens: [ 256, 256 ]
-        policy_model:
+        policy_model_config:
             fcnet_activation: relu
             fcnet_hiddens: [ 256, 256 ]
         tau: 0.005
diff --git a/rllib/tuned_examples/sac/pendulum-sac.yaml b/rllib/tuned_examples/sac/pendulum-sac.yaml
index 7b804ccbc..84f9ca98b 100644
--- a/rllib/tuned_examples/sac/pendulum-sac.yaml
+++ b/rllib/tuned_examples/sac/pendulum-sac.yaml
@@ -12,10 +12,10 @@ pendulum-sac:
         framework: tf
         horizon: 200
         soft_horizon: false
-        Q_model:
+        q_model_config:
             fcnet_activation: relu
             fcnet_hiddens: [256, 256]
-        policy_model:
+        policy_model_config:
             fcnet_activation: relu
             fcnet_hiddens: [256, 256]
         tau: 0.005
diff --git a/rllib/tuned_examples/sac/pendulum-transformed-actions-sac.yaml b/rllib/tuned_examples/sac/pendulum-transformed-actions-sac.yaml
index 7d0bfb862..89cac16dd 100644
--- a/rllib/tuned_examples/sac/pendulum-transformed-actions-sac.yaml
+++ b/rllib/tuned_examples/sac/pendulum-transformed-actions-sac.yaml
@@ -19,10 +19,10 @@ transformed-actions-pendulum-sac-dummy-torch:
         horizon: 200
         soft_horizon: false

-        Q_model:
+        q_model_config:
             fcnet_activation: relu
             fcnet_hiddens: [256, 256]
-        policy_model:
+        policy_model_config:
             fcnet_activation: relu
             fcnet_hiddens: [256, 256]
         tau: 0.005
diff --git a/rllib/utils/replay_buffers/utils.py b/rllib/utils/replay_buffers/utils.py
index 68f6fa7c7..3485eed5c 100644
--- a/rllib/utils/replay_buffers/utils.py
+++ b/rllib/utils/replay_buffers/utils.py
@@ -138,6 +138,13 @@ def validate_buffer_config(config: dict):
     if config.get("replay_buffer_config", None) is None:
         config["replay_buffer_config"] = {}

+    if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
+        deprecation_warning(
+            old="config['worker_side_prioritization']",
+            new="config['replay_buffer_config']['worker_side_prioritization']",
+            error=True,
+        )
+
     prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
     if prioritized_replay != DEPRECATED_VALUE:
         deprecation_warning(