import logging
from typing import Optional, Type

from ray.rllib.algorithms.dqn import DQN, DQNConfig
from ray.rllib.algorithms.r2d2.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.algorithms.r2d2.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.typing import AlgorithmConfigDict

logger = logging.getLogger(__name__)


class R2D2Config(DQNConfig):
    """Defines a configuration class from which an R2D2 Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> config = R2D2Config()
        >>> print(config.h_function_epsilon)
        >>> replay_config = config.replay_buffer_config
        >>> replay_config.update(
        >>>     {
        >>>         "capacity": 1000000,
        >>>         "replay_burn_in": 20,
        >>>     }
        >>> )
        >>> config.training(replay_buffer_config=replay_config)\
        >>>     .resources(num_gpus=1)\
        >>>     .rollouts(num_rollout_workers=30)\
        >>>     .environment("CartPole-v1")
        >>> algo = R2D2(config=config)
        >>> while True:
        >>>     algo.train()

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> from ray import tune
        >>> config = R2D2Config()
        >>> config.training(train_batch_size=tune.grid_search([256, 64]))
        >>> config.environment(env="CartPole-v1")
        >>> tune.run(
        >>>     "R2D2",
        >>>     stop={"episode_reward_mean": 200},
        >>>     config=config.to_dict(),
        >>> )

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> config = R2D2Config()
        >>> print(config.exploration_config)
        >>> explore_config = config.exploration_config
        >>> explore_config.update(
        >>>     {
        >>>         "initial_epsilon": 1.0,
        >>>         "final_epsilon": 0.1,
        >>>         "epsilon_timesteps": 200000,
        >>>     }
        >>> )
        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
        >>>     .exploration(exploration_config=explore_config)

    Example:
        >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
        >>> config = R2D2Config()
        >>> print(config.exploration_config)
        >>> explore_config = config.exploration_config
        >>> explore_config.update(
        >>>     {
        >>>         "type": "SoftQ",
        >>>         "temperature": [1.0],
        >>>     }
        >>> )
        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
        >>>     .exploration(exploration_config=explore_config)
    """

    def __init__(self, algo_class=None):
        """Initializes an R2D2Config instance."""
        super().__init__(algo_class=algo_class or R2D2)

        # fmt: off
        # __sphinx_doc_begin__
        # R2D2-specific settings:
        self.zero_init_states = True
        self.use_h_function = True
        self.h_function_epsilon = 1e-3

        # R2D2 settings overriding DQN ones:
        # .training()
        self.adam_epsilon = 1e-3
        self.lr = 1e-4
        self.gamma = 0.997
        self.train_batch_size = 1000
        self.target_network_update_freq = 1000
        self.training_intensity = 150
        # R2D2 uses a buffer that stores sequences.
        self.replay_buffer_config = {
            "type": "MultiAgentReplayBuffer",
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            # Size of the replay buffer (in sequences, not timesteps).
            "capacity": 100000,
            # This algorithm learns on sequences. We therefore require the replay
            # buffer to slice sampled batches into sequences before replay. How
            # sequences are sliced depends on the parameters
            # `replay_sequence_length`, `replay_burn_in`, and
            # `replay_zero_init_states`.
            "storage_unit": "sequences",
            # Set automatically: The number of contiguous environment steps to
            # replay at once. Will be calculated as
            # `model->max_seq_len + replay_burn_in`. Do not set this manually.
            "replay_sequence_length": -1,
            # If > 0, use the `replay_burn_in` first steps of each replay-sampled
            # sequence (starting either from all 0.0-values if `zero_init_state=True`
            # or from the already stored values) to calculate more accurate
            # initial states for the actual sequence (starting after this burn-in
            # window). In the burn-in case, the actual length of the sequence
            # used for loss calculation is `n - replay_burn_in` time steps
            # (n=LSTM's/attention net's max_seq_len).
            "replay_burn_in": 0,
        }
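        # Editor's note (illustrative, not in the original source): with
        # `model["max_seq_len"] = 20` and `replay_burn_in = 10`,
        # `validate_config()` below sets `replay_sequence_length` to 30; the
        # first 10 steps of each replayed sequence only serve to produce an
        # initial recurrent state and do not contribute to the loss.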

        # .rollouts()
        self.num_workers = 2
        self.batch_mode = "complete_episodes"

        # fmt: on
        # __sphinx_doc_end__

        self.burn_in = DEPRECATED_VALUE

    def training(
        self,
        *,
        zero_init_states: Optional[bool] = None,
        use_h_function: Optional[bool] = None,
        h_function_epsilon: Optional[float] = None,
        **kwargs,
    ) -> "R2D2Config":
        """Sets the training related configuration.

        Args:
            zero_init_states: If True, assume a zero-initialized state input (no
                matter where in the episode the sequence is located).
                If False, store the initial states along with each SampleBatch,
                use them (as initial state when running through the network for
                training), and update those initial states during training (from
                the internal state outputs of the immediately preceding sequence).
            use_h_function: Whether to use the h-function from the paper [1] to
                scale target values in the R2D2 loss function:
                h(x) = sign(x)(√(|x| + 1) − 1) + εx
            h_function_epsilon: The epsilon parameter of the R2D2 loss function
                (only used if `use_h_function=True`).

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if zero_init_states is not None:
            self.zero_init_states = zero_init_states
        if use_h_function is not None:
            self.use_h_function = use_h_function
        if h_function_epsilon is not None:
            self.h_function_epsilon = h_function_epsilon

        return self


class R2D2(DQN):
    """Recurrent Experience Replay in Distributed Reinforcement Learning (R2D2).

    Algorithm defining the distributed R2D2 algorithm.
    See `r2d2_[tf|torch]_policy.py` for the definition of the policies.

    [1] Recurrent Experience Replay in Distributed Reinforcement Learning -
    S Kapturowski, G Ostrovski, J Quan, R Munos, W Dabney - 2019, DeepMind

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#\
    recurrent-replay-distributed-dqn-r2d2
    """

    @classmethod
    @override(DQN)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return R2D2Config().to_dict()

    @override(DQN)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            return R2D2TorchPolicy
        else:
            return R2D2TFPolicy

    @override(DQN)
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        """Checks and updates the config based on settings.

        Rewrites the replay buffer's `replay_sequence_length` to take
        burn-in and the model's `max_seq_len` into account.
        """
        # Call super's validation method.
        super().validate_config(config)

        if config["replay_buffer_config"].get("replay_sequence_length", -1) != -1:
            raise ValueError(
                "`replay_sequence_length` is calculated automatically to be "
                "model->max_seq_len + burn_in!"
            )
        # Set the replay sequence length to the model's max_seq_len plus the
        # burn-in window.
        config["replay_buffer_config"]["replay_sequence_length"] = (
            config["replay_buffer_config"]["replay_burn_in"]
            + config["model"]["max_seq_len"]
        )

        if config.get("batch_mode") != "complete_episodes":
            raise ValueError("`batch_mode` must be 'complete_episodes'!")
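
        # Editor's note (illustrative, not in the original source): complete
        # episodes are required so that the buffer can slice whole episodes
        # into sequences whose stored initial RNN states line up with the
        # sequence starts; truncated rollouts could otherwise split an episode
        # across sample batches.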


# Deprecated: Use ray.rllib.algorithms.r2d2.r2d2.R2D2Config instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(R2D2Config().to_dict())

    @Deprecated(
        old="ray.rllib.agents.dqn.r2d2::R2D2_DEFAULT_CONFIG",
        new="ray.rllib.algorithms.r2d2.r2d2::R2D2Config(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


R2D2_DEFAULT_CONFIG = _deprecated_default_config()
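

# Editor's note: the following is an illustrative usage sketch, not part of the
# original RLlib module. It assumes a local Ray installation and the Gym
# environment "CartPole-v1", mirroring the examples in the R2D2Config docstring.
if __name__ == "__main__":
    # Build a small R2D2 config, construct the algorithm, and run one
    # training iteration.
    config = (
        R2D2Config()
        .environment(env="CartPole-v1")
        .rollouts(num_rollout_workers=1)
        .training(use_h_function=True, h_function_epsilon=1e-3)
    )
    algo = R2D2(config=config)
    print(algo.train())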