from typing import Type, Optional

from ray.rllib.algorithms.sac import (
    SACTrainer,
    SACConfig,
)
from ray.rllib.algorithms.sac.rnnsac_torch_policy import RNNSACTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated


class RNNSACConfig(SACConfig):
    """Defines a configuration class from which an RNNSACTrainer can be built.

    Example:
        >>> config = RNNSACConfig().training(gamma=0.9, lr=0.01)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_rollout_workers=4)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()
    """

    def __init__(self, trainer_class=None):
        super().__init__(trainer_class=trainer_class or RNNSACTrainer)
        # fmt: off
        # __sphinx_doc_begin__
        self.burn_in = DEPRECATED_VALUE
        self.batch_mode = "complete_episodes"
        self.zero_init_states = True
        self.replay_buffer_config["replay_burn_in"] = 0
        # Set automatically: The number of contiguous environment steps to
        # replay at once. Will be calculated via
        # model->max_seq_len + burn_in.
        # Do not set this to any valid value!
        self.replay_buffer_config["replay_sequence_length"] = -1

        # fmt: on
        # __sphinx_doc_end__

    @override(SACConfig)
    def training(
        self,
        *,
        zero_init_states: Optional[bool] = None,
        **kwargs,
    ) -> "RNNSACConfig":
        """Sets the training related configuration.

        Args:
            zero_init_states: If True, assume a zero-initialized state input
                (no matter where in the episode the sequence is located).
                If False, store the initial states along with each SampleBatch, use
                it (as initial state when running through the network for training),
                and update that initial state during training (from the internal
                state outputs of the immediately preceding sequence).

        Returns:
            This updated TrainerConfig object.
        """
        super().training(**kwargs)
        if zero_init_states is not None:
            self.zero_init_states = zero_init_states

        return self
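
    # Usage sketch (illustrative, not part of the original module): besides the
    # RNNSAC-specific `zero_init_states` flag, all other keyword arguments to
    # `training()` are forwarded to `SACConfig.training()`, e.g.:
    #
    #   config = (
    #       RNNSACConfig()
    #       .training(lr=3e-4, zero_init_states=False)
    #       .rollouts(num_rollout_workers=2)
    #   )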


class RNNSACTrainer(SACTrainer):
    @classmethod
    @override(SACTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return RNNSACConfig().to_dict()

    @override(SACTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        # Add the `burn_in` to the Model's max_seq_len.
        replay_sequence_length = (
            config["replay_buffer_config"]["replay_burn_in"]
            + config["model"]["max_seq_len"]
        )
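        # Worked example (assumed, non-default values): with
        # config["model"]["max_seq_len"] = 20 and
        # config["replay_buffer_config"]["replay_burn_in"] = 10, sequences of
        # 10 + 20 = 30 contiguous timesteps are stored and replayed together.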
        # Check whether the user tried to set `replay_sequence_length` to
        # anything other than the auto-computed value.
        if config["replay_buffer_config"].get("replay_sequence_length", None) not in [
            None,
            -1,
            replay_sequence_length,
        ]:
            raise ValueError(
                "`replay_sequence_length` is calculated automatically to be "
                "config['model']['max_seq_len'] + "
                "config['replay_buffer_config']['replay_burn_in']. Leave "
                "config['replay_buffer_config']['replay_sequence_length'] unset "
                "to avoid this error."
            )
        # Set the replay sequence length to the model's max_seq_len plus burn-in.
        config["replay_buffer_config"][
            "replay_sequence_length"
        ] = replay_sequence_length

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` supported so far for RNNSACTrainer!"
            )

    @override(SACTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return RNNSACTorchPolicy


class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(RNNSACConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.sac.rnnsac.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.sac.rnnsac.RNNSACConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
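
# Usage sketch (illustrative, not part of the original module): reading keys
# from the old dict-style config still works, but goes through the @Deprecated
# `__getitem__` above, which emits a deprecation warning pointing users to the
# `RNNSACConfig` API:
#
#   gamma = DEFAULT_CONFIG["gamma"]               # deprecated access path
#   config = RNNSACConfig().training(gamma=0.95)  # preferred replacement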