ray/rllib/algorithms/sac/rnnsac.py

from typing import Type, Optional
from ray.rllib.algorithms.sac import (
    SACTrainer,
    SACConfig,
)
from ray.rllib.algorithms.sac.rnnsac_torch_policy import RNNSACTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated


class RNNSACConfig(SACConfig):
    """Defines a configuration class from which an RNNSACTrainer can be built.

    Example:
        >>> config = RNNSACConfig().training(gamma=0.9, lr=0.01)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_rollout_workers=4)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()
    """

    def __init__(self, trainer_class=None):
        super().__init__(trainer_class=trainer_class or RNNSACTrainer)
        # fmt: off
        # __sphinx_doc_begin__
        self.burn_in = DEPRECATED_VALUE
        self.batch_mode = "complete_episodes"
        self.zero_init_states = True
        self.replay_buffer_config["replay_burn_in"] = 0
        # Set automatically: The number of contiguous environment steps to
        # replay at once. This is calculated in `validate_config()` as
        # `model->max_seq_len + replay_burn_in`.
        # Do not set this manually!
        self.replay_buffer_config["replay_sequence_length"] = -1
        # fmt: on
        # __sphinx_doc_end__

    @override(SACConfig)
    def training(
        self,
        *,
        zero_init_states: Optional[bool] = None,
        **kwargs,
    ) -> "RNNSACConfig":
        """Sets the training-related configuration.

        Args:
            zero_init_states: If True, assume a zero-initialized state input (no
                matter where in the episode the sequence is located).
                If False, store the initial states along with each SampleBatch, use
                it (as initial state when running through the network for training),
                and update that initial state during training (from the internal
                state outputs of the immediately preceding sequence).

        Returns:
            This updated TrainerConfig object.
        """
        super().training(**kwargs)
        if zero_init_states is not None:
            self.zero_init_states = zero_init_states

        return self
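
# A minimal usage sketch of the `training()` override above, following the
# chaining pattern shown in the class docstring (names like `gamma` and
# `num_rollout_workers` are standard RLlib config keys):
#
#     config = (
#         RNNSACConfig()
#         .training(gamma=0.95, zero_init_states=False)
#         .rollouts(num_rollout_workers=2)
#     )
#     trainer = config.build(env="CartPole-v1")
#
# With `zero_init_states=False`, the initial RNN states stored in each
# SampleBatch are used (and updated) during training instead of zeros.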


class RNNSACTrainer(SACTrainer):
    @classmethod
    @override(SACTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return RNNSACConfig().to_dict()

    @override(SACTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        # Add the `replay_burn_in` to the Model's `max_seq_len`.
        replay_sequence_length = (
            config["replay_buffer_config"]["replay_burn_in"]
            + config["model"]["max_seq_len"]
        )
        # Check whether the user tries to set `replay_sequence_length` (to
        # anything other than the automatically derived value).
        if config["replay_buffer_config"].get("replay_sequence_length", None) not in [
            None,
            -1,
            replay_sequence_length,
        ]:
            raise ValueError(
                "`replay_sequence_length` is calculated automatically to be "
                "config['model']['max_seq_len'] + "
                "config['replay_buffer_config']['replay_burn_in']. Leave "
                "config['replay_buffer_config']['replay_sequence_length'] blank "
                "to avoid this error."
            )
        # Set the replay sequence length to `max_seq_len` + `replay_burn_in`.
        config["replay_buffer_config"][
            "replay_sequence_length"
        ] = replay_sequence_length

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` is supported so far for RNNSACTrainer!"
            )
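
    # A worked example of the derivation above (assuming RLlib's default model
    # `max_seq_len` of 20): with `replay_burn_in=0` the buffer replays
    # sequences of 0 + 20 = 20 contiguous env steps, while `replay_burn_in=10`
    # would yield 10 + 20 = 30. Any user-supplied `replay_sequence_length`
    # other than None, -1, or that sum raises the ValueError above.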

    @override(SACTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return RNNSACTorchPolicy


class _deprecated_default_config(dict):
    """Wraps the old-style default config dict and warns on item access."""

    def __init__(self):
        super().__init__(RNNSACConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.sac.rnnsac.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.sac.rnnsac.RNNSACConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
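
# Migration sketch for the deprecation declared above (key names such as
# "gamma" are standard top-level RLlib config keys):
#
#     # Old, deprecated access pattern (item access emits a deprecation
#     # warning via `_deprecated_default_config.__getitem__`):
#     gamma = DEFAULT_CONFIG["gamma"]
#
#     # New, preferred pattern:
#     gamma = RNNSACConfig().to_dict()["gamma"]
#     # ...or configure and build directly:
#     trainer = RNNSACConfig().training(gamma=0.95).build(env="CartPole-v1")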