[RLlib] SAC, RNNSAC, and CQL TrainerConfig objects (#25059)
This commit is contained in:
parent
44773e810b
commit
501d932449
42 changed files with 662 additions and 411 deletions
|
@ -438,7 +438,7 @@ Soft Actor Critic (SAC)
|
|||
SAC architecture (same as DQN)
|
||||
|
||||
RLlib's soft actor-critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs.
|
||||
Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``, the ``model`` field of the config will be ignored.
|
||||
Note that SAC has two fields to configure for custom models: ``policy_model_config`` and ``q_model_config``; the ``model`` field of the config will be ignored.
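For example, a custom stack for both networks can be configured through these keys (a minimal sketch; the hyperparameter values shown are illustrative)::

    from ray.rllib.algorithms.sac import SACConfig

    config = (
        SACConfig()
        .environment(env="Pendulum-v1")
        .training(
            q_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
            policy_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
        )
    )
    trainer = config.build()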
|
||||
|
||||
Tuned examples (continuous actions):
|
||||
`Pendulum-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/sac/pendulum-sac.yaml>`__,
|
||||
|
|
|
@ -106,7 +106,7 @@ This `runnable example <https://github.com/ray-project/ray/blob/master/rllib/exa
|
|||
On-policy algorithms and experience postprocessing
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
RLlib assumes that input batches are of `postprocessed experiences <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/policy/policy.py#L376>`__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/agents/dqn/dqn_tf_policy.py#L387>`__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data.
|
||||
RLlib assumes that input batches consist of `postprocessed experiences <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/policy/policy.py#L376>`__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing <https://github.com/ray-project/ray/blob/cf21c634a390745ba6f8916b1f34f7b0453bc7dd/rllib/agents/dqn/dqn_tf_policy.py#L387>`__ is only needed if ``n_step > 1`` or ``replay_buffer_config.worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data.
|
||||
|
||||
However, for on-policy algorithms like PPO, you'll need to pass in the extra values added during policy evaluation and postprocessing to ``batch_builder.add_values()``, e.g., ``logits``, ``vf_preds``, ``value_target``, and ``advantages`` for PPO. This is needed since the calculation of these values depends on the parameters of the *behaviour* policy, which RLlib does not have access to in the offline setting (in online training, these values are automatically added during policy evaluation).
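A minimal sketch of writing such a batch yourself (column names follow RLlib's ``SampleBatch``/GAE conventions; adapt the exact set of extra columns to your algorithm)::

    import numpy as np
    from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

    batch_builder = SampleBatchBuilder()
    # One dummy timestep: the standard columns plus the extra values that
    # PPO's policy evaluation and postprocessing would have produced under
    # the behaviour policy.
    batch_builder.add_values(
        t=0,
        eps_id=0,
        agent_index=0,
        obs=np.zeros(4),
        actions=0,
        rewards=1.0,
        dones=False,
        action_dist_inputs=np.zeros(2),  # behaviour-policy logits
        vf_preds=0.0,                    # value-function prediction
        advantages=0.0,                  # from GAE postprocessing
        value_targets=0.0,
    )
    batch = batch_builder.build_and_reset()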
|
||||
|
||||
|
|
|
@ -13,10 +13,10 @@ cql-halfcheetahbulletenv-v0:
|
|||
|
||||
soft_horizon: False
|
||||
horizon: 1000
|
||||
Q_model:
|
||||
q_model_config:
|
||||
fcnet_activation: relu
|
||||
fcnet_hiddens: [256, 256, 256]
|
||||
policy_model:
|
||||
policy_model_config:
|
||||
fcnet_activation: relu
|
||||
fcnet_hiddens: [256, 256, 256]
|
||||
tau: 0.005
|
||||
|
|
|
@ -10,10 +10,10 @@ sac-halfcheetahbulletenv-v0:
|
|||
config:
|
||||
horizon: 1000
|
||||
soft_horizon: false
|
||||
Q_model:
|
||||
q_model_config:
|
||||
fcnet_activation: relu
|
||||
fcnet_hiddens: [256, 256]
|
||||
policy_model:
|
||||
policy_model_config:
|
||||
fcnet_activation: relu
|
||||
fcnet_hiddens: [256, 256]
|
||||
tau: 0.005
|
||||
|
|
|
@ -80,7 +80,7 @@ def _import_bc():
|
|||
def _import_cql():
|
||||
from ray.rllib.algorithms import cql
|
||||
|
||||
return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
|
||||
return cql.CQLTrainer, cql.DEFAULT_CONFIG
|
||||
|
||||
|
||||
def _import_ddpg():
|
||||
|
|
|
@ -849,7 +849,7 @@ class TrainerConfig:
|
|||
self.evaluation_num_workers = evaluation_num_workers
|
||||
if custom_evaluation_function is not None:
|
||||
self.custom_evaluation_function = custom_evaluation_function
|
||||
if self.always_attach_evaluation_results:
|
||||
if always_attach_evaluation_results:
|
||||
self.always_attach_evaluation_results = always_attach_evaluation_results
|
||||
|
||||
return self
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.cql.cql import CQLTrainer, DEFAULT_CONFIG, CQLConfig
|
||||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||
|
||||
__all__ = [
|
||||
"CQL_DEFAULT_CONFIG",
|
||||
"DEFAULT_CONFIG",
|
||||
"CQLTorchPolicy",
|
||||
"CQLTrainer",
|
||||
"CQLConfig",
|
||||
]
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
|
||||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||
from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
|
||||
from ray.rllib.algorithms.sac.sac import (
|
||||
SACTrainer,
|
||||
SACConfig,
|
||||
)
|
||||
from ray.rllib.execution.train_ops import (
|
||||
multi_gpu_train_one_step,
|
||||
train_one_step,
|
||||
|
@ -12,9 +15,12 @@ from ray.rllib.execution.train_ops import (
|
|||
from ray.rllib.offline.shuffled_input import ShuffledInput
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.utils.deprecation import (
|
||||
DEPRECATED_VALUE,
|
||||
deprecation_warning,
|
||||
Deprecated,
|
||||
)
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
||||
from ray.rllib.utils.metrics import (
|
||||
LAST_TARGET_UPDATE_TS,
|
||||
|
@ -31,38 +37,86 @@ tf1, tf, tfv = try_import_tf()
|
|||
tfp = try_import_tfp()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
CQL_DEFAULT_CONFIG = merge_dicts(
|
||||
SAC_CONFIG, {
|
||||
# You should override this to point to an offline dataset.
|
||||
"input": "sampler",
|
||||
# Switch off off-policy evaluation.
|
||||
"input_evaluation": [],
|
||||
# Number of iterations with Behavior Cloning Pretraining.
|
||||
"bc_iters": 20000,
|
||||
# CQL loss temperature.
|
||||
"temperature": 1.0,
|
||||
# Number of actions to sample for CQL loss.
|
||||
"num_actions": 10,
|
||||
# Whether to use the Lagrangian for Alpha Prime (in CQL loss).
|
||||
"lagrangian": False,
|
||||
# Lagrangian threshold.
|
||||
"lagrangian_thresh": 5.0,
|
||||
# Min Q weight multiplier.
|
||||
"min_q_weight": 5.0,
|
||||
# Reporting: As CQL is offline (no sampling steps), we need to limit
|
||||
# `self.train()` reporting by the number of steps trained (not sampled).
|
||||
"min_sample_timesteps_per_reporting": 0,
|
||||
"min_train_timesteps_per_reporting": 100,
|
||||
|
||||
# Deprecated keys.
|
||||
# Use `min_sample_timesteps_per_reporting` and
|
||||
# `min_train_timesteps_per_reporting` instead.
|
||||
"timesteps_per_iteration": DEPRECATED_VALUE,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
class CQLConfig(SACConfig):
|
||||
"""Defines a configuration class from which a CQLTrainer can be built.
|
||||
|
||||
Example:
|
||||
>>> config = CQLConfig().training(gamma=0.9, lr=0.01)\
|
||||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
super().__init__(trainer_class=trainer_class or CQLTrainer)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# CQL-specific config settings:
|
||||
self.bc_iters = 20000
|
||||
self.temperature = 1.0
|
||||
self.num_actions = 10
|
||||
self.lagrangian = False
|
||||
self.lagrangian_thresh = 5.0
|
||||
self.min_q_weight = 5.0
|
||||
|
||||
# Changes to Trainer's/SACConfig's default:
|
||||
# .offline_data()
|
||||
self.input_evaluation = []
|
||||
|
||||
# .reporting()
|
||||
self.min_sample_timesteps_per_reporting = 0
|
||||
self.min_train_timesteps_per_reporting = 100
|
||||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
||||
self.timesteps_per_iteration = DEPRECATED_VALUE
|
||||
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
bc_iters: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
num_actions: Optional[int] = None,
|
||||
lagrangian: Optional[bool] = None,
|
||||
lagrangian_thresh: Optional[float] = None,
|
||||
min_q_weight: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> "CQLConfig":
|
||||
"""Sets the training-related configuration.
|
||||
|
||||
Args:
|
||||
bc_iters: Number of iterations with Behavior Cloning pretraining.
|
||||
temperature: CQL loss temperature.
|
||||
num_actions: Number of actions to sample for CQL loss.
|
||||
lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
|
||||
lagrangian_thresh: Lagrangian threshold.
|
||||
min_q_weight: Min Q weight multiplier.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
||||
if bc_iters is not None:
|
||||
self.bc_iters = bc_iters
|
||||
if temperature is not None:
|
||||
self.temperature = temperature
|
||||
if num_actions is not None:
|
||||
self.num_actions = num_actions
|
||||
if lagrangian is not None:
|
||||
self.lagrangian = lagrangian
|
||||
if lagrangian_thresh is not None:
|
||||
self.lagrangian_thresh = lagrangian_thresh
|
||||
if min_q_weight is not None:
|
||||
self.min_q_weight = min_q_weight
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class CQLTrainer(SACTrainer):
|
||||
|
@ -111,7 +165,7 @@ class CQLTrainer(SACTrainer):
|
|||
@classmethod
|
||||
@override(SACTrainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return CQL_DEFAULT_CONFIG
|
||||
return CQLConfig().to_dict()
|
||||
|
||||
@override(SACTrainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
|
@ -208,3 +262,20 @@ class CQLTrainer(SACTrainer):
|
|||
|
||||
# Return all collected metrics for the iteration.
|
||||
return train_results
|
||||
|
||||
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(CQLConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG",
|
||||
new="ray.rllib.algorithms.cql.cql.CQLConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _deprecated_default_config()
|
||||
CQL_DEFAULT_CONFIG = DEFAULT_CONFIG
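A minimal usage sketch of the new builder API that replaces the old config dict (the offline-data path below is a hypothetical placeholder):

    from ray.rllib.algorithms.cql import CQLConfig

    config = (
        CQLConfig()
        .training(bc_iters=20000, temperature=1.0, num_actions=10)
        .offline_data(input_="/tmp/cql_offline_data.json")  # hypothetical path
    )
    trainer = config.build(env="Pendulum-v1")
    print(trainer.train())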
|
||||
|
|
|
@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):
|
|||
CQLTFPolicy = build_tf_policy(
|
||||
name="CQLTFPolicy",
|
||||
loss_fn=cql_loss,
|
||||
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG,
|
||||
validate_spaces=validate_spaces,
|
||||
stats_fn=cql_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
|
|
|
@ -390,7 +390,7 @@ CQLTorchPolicy = build_policy_class(
|
|||
name="CQLTorchPolicy",
|
||||
framework="torch",
|
||||
loss_fn=cql_loss,
|
||||
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG,
|
||||
stats_fn=cql_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.algorithms.cql as cql
|
||||
from ray.rllib.algorithms import cql
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
||||
|
@ -38,36 +38,43 @@ class TestCQL(unittest.TestCase):
|
|||
data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json")
|
||||
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
|
||||
|
||||
config = {
|
||||
"env": "Pendulum-v1",
|
||||
"input": [data_file],
|
||||
# In the files, we use here for testing, actions have already
|
||||
# been normalized.
|
||||
# This is usually the case when the file was generated by another
|
||||
# RLlib algorithm (e.g. PPO or SAC).
|
||||
"actions_in_input_normalized": False,
|
||||
"clip_actions": True,
|
||||
"train_batch_size": 2000,
|
||||
"twin_q": True,
|
||||
"replay_buffer_config": {"learning_starts": 0},
|
||||
"bc_iters": 2, # 2 BC iters, 2 CQL iters.
|
||||
"rollout_fragment_length": 1,
|
||||
# Switch on off-policy evaluation.
|
||||
"input_evaluation": ["is"],
|
||||
"always_attach_evaluation_results": True,
|
||||
"evaluation_interval": 2,
|
||||
"evaluation_duration": 10,
|
||||
"evaluation_config": {
|
||||
"input": "sampler",
|
||||
},
|
||||
"evaluation_parallel_to_training": False,
|
||||
"evaluation_num_workers": 2,
|
||||
}
|
||||
config = (
|
||||
cql.CQLConfig()
|
||||
.environment(
|
||||
env="Pendulum-v1",
|
||||
)
|
||||
.offline_data(
|
||||
input_=[data_file],
|
||||
# In the files we use here for testing, actions have already
|
||||
# been normalized.
|
||||
# This is usually the case when the file was generated by another
|
||||
# RLlib algorithm (e.g. PPO or SAC).
|
||||
actions_in_input_normalized=False,
|
||||
# Switch on off-policy evaluation.
|
||||
input_evaluation=["is"],
|
||||
)
|
||||
.training(
|
||||
clip_actions=False,
|
||||
train_batch_size=2000,
|
||||
twin_q=True,
|
||||
replay_buffer_config={"learning_starts": 0},
|
||||
bc_iters=2,
|
||||
)
|
||||
.evaluation(
|
||||
always_attach_evaluation_results=True,
|
||||
evaluation_interval=2,
|
||||
evaluation_duration=10,
|
||||
evaluation_config={"input": "sampler"},
|
||||
evaluation_parallel_to_training=False,
|
||||
evaluation_num_workers=2,
|
||||
)
|
||||
.rollouts(rollout_fragment_length=1)
|
||||
)
|
||||
num_iterations = 4
|
||||
|
||||
# Test for tf/torch frameworks.
|
||||
for fw in framework_iterator(config, with_eager_tracing=True):
|
||||
trainer = cql.CQLTrainer(config=config)
|
||||
trainer = config.build()
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
check_train_results(results)
|
||||
|
|
|
@ -124,9 +124,11 @@ class DDPGConfig(SimpleQConfig):
|
|||
|
||||
# .rollouts()
|
||||
self.rollout_fragment_length = 1
|
||||
self.worker_side_prioritization = False
|
||||
self.compress_observations = False
|
||||
|
||||
# Deprecated.
|
||||
self.worker_side_prioritization = DEPRECATED_VALUE
|
||||
|
||||
@override(TrainerConfig)
|
||||
def training(
|
||||
self,
|
||||
|
@ -148,7 +150,6 @@ class DDPGConfig(SimpleQConfig):
|
|||
use_huber: Optional[bool] = None,
|
||||
huber_threshold: Optional[float] = None,
|
||||
l2_reg: Optional[float] = None,
|
||||
worker_side_prioritization: Optional[bool] = None,
|
||||
training_intensity: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> "DDPGConfig":
|
||||
|
@ -190,7 +191,6 @@ class DDPGConfig(SimpleQConfig):
|
|||
use_huber: Conventionally, no need to clip gradients if using a huber loss
|
||||
huber_threshold: Threshold of a huber loss
|
||||
l2_reg: Weights for L2 regularization
|
||||
worker_side_prioritization: Whether to compute priorities on workers.
|
||||
training_intensity: The intensity with which to update the model
|
||||
(vs collecting samples from
|
||||
the env). If None, uses the "natural" value of:
|
||||
|
@ -246,8 +246,6 @@ class DDPGConfig(SimpleQConfig):
|
|||
self.huber_threshold = huber_threshold
|
||||
if l2_reg is not None:
|
||||
self.l2_reg = l2_reg
|
||||
if worker_side_prioritization is not None:
|
||||
self.worker_side_prioritization = worker_side_prioritization
|
||||
if training_intensity is not None:
|
||||
self.training_intensity = training_intensity
|
||||
|
||||
|
|
|
@ -135,7 +135,7 @@ class DQNConfig(SimpleQConfig):
|
|||
self.before_learn_on_batch = None
|
||||
self.training_intensity = None
|
||||
|
||||
# Changes to SimpleQConfig default
|
||||
# Changes to SimpleQConfig's default:
|
||||
self.replay_buffer_config = {
|
||||
"type": "MultiAgentPrioritizedReplayBuffer",
|
||||
# Specify prioritized replay by supplying a buffer type that supports
|
||||
|
@ -176,7 +176,6 @@ class DQNConfig(SimpleQConfig):
|
|||
Type[MultiAgentBatch],
|
||||
] = None,
|
||||
training_intensity: Optional[float] = None,
|
||||
worker_side_prioritization: Optional[bool] = None,
|
||||
replay_buffer_config: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> "DQNConfig":
|
||||
|
@ -213,7 +212,6 @@ class DQNConfig(SimpleQConfig):
|
|||
-> will make sure that replay+train op will be executed 4x as often as
|
||||
rollout+insert op (4 * 250 = 1000).
|
||||
See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
|
||||
worker_side_prioritization: Whether to compute priorities on workers.
|
||||
replay_buffer_config: Replay buffer config.
|
||||
Examples:
|
||||
{
|
||||
|
@ -278,11 +276,11 @@ class DQNConfig(SimpleQConfig):
|
|||
self.before_learn_on_batch = before_learn_on_batch
|
||||
if training_intensity is not None:
|
||||
self.training_intensity = training_intensity
|
||||
if worker_side_prioritization is not None:
|
||||
self.worker_side_prioritization = worker_side_prioritization
|
||||
if replay_buffer_config is not None:
|
||||
self.replay_buffer_config = replay_buffer_config
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
|
||||
class _deprecated_default_config(dict):
|
||||
|
|
|
@ -71,7 +71,6 @@ class QMixConfig(SimpleQConfig):
|
|||
self.optim_alpha = 0.99
|
||||
self.optim_eps = 0.00001
|
||||
self.grad_norm_clipping = 10
|
||||
self.worker_side_prioritization = False
|
||||
|
||||
# Override some of TrainerConfig's default values with QMix-specific values.
|
||||
# .training()
|
||||
|
@ -136,6 +135,8 @@ class QMixConfig(SimpleQConfig):
|
|||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
self.worker_side_prioritization = DEPRECATED_VALUE
|
||||
|
||||
@override(SimpleQConfig)
|
||||
def training(
|
||||
self,
|
||||
|
@ -148,7 +149,6 @@ class QMixConfig(SimpleQConfig):
|
|||
optim_alpha: Optional[float] = None,
|
||||
optim_eps: Optional[float] = None,
|
||||
grad_norm_clipping: Optional[float] = None,
|
||||
worker_side_prioritization: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> "QMixConfig":
|
||||
"""Sets the training related configuration.
|
||||
|
@ -164,8 +164,6 @@ class QMixConfig(SimpleQConfig):
|
|||
optim_eps: RMSProp epsilon.
|
||||
grad_norm_clipping: If not None, clip gradients during optimization at
|
||||
this value.
|
||||
worker_side_prioritization: Whether to compute priorities for the replay
|
||||
buffer on worker side.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
|
@ -189,8 +187,6 @@ class QMixConfig(SimpleQConfig):
|
|||
self.optim_eps = optim_eps
|
||||
if grad_norm_clipping is not None:
|
||||
self.grad_norm_clipping = grad_norm_clipping
|
||||
if worker_side_prioritization is not None:
|
||||
self.worker_side_prioritization = worker_side_prioritization
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.sac.sac import SACTrainer, DEFAULT_CONFIG, SACConfig
|
||||
from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
|
||||
from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy
|
||||
|
||||
|
@ -6,14 +6,16 @@ from ray.rllib.algorithms.sac.rnnsac import (
|
|||
RNNSACTrainer,
|
||||
DEFAULT_CONFIG as RNNSAC_DEFAULT_CONFIG,
|
||||
)
|
||||
from ray.rllib.algorithms.sac.rnnsac import RNNSACTorchPolicy
|
||||
from ray.rllib.algorithms.sac.rnnsac import RNNSACTorchPolicy, RNNSACConfig
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CONFIG",
|
||||
"SACTFPolicy",
|
||||
"SACTorchPolicy",
|
||||
"SACTrainer",
|
||||
"SACConfig",
|
||||
"RNNSAC_DEFAULT_CONFIG",
|
||||
"RNNSACTorchPolicy",
|
||||
"RNNSACTrainer",
|
||||
"RNNSACConfig",
|
||||
]
|
||||
|
|
|
@ -1,50 +1,78 @@
|
|||
from typing import Type
|
||||
from typing import Type, Optional
|
||||
|
||||
from ray.rllib.algorithms.sac import SACTrainer, DEFAULT_CONFIG as SAC_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.sac import (
|
||||
SACTrainer,
|
||||
SACConfig,
|
||||
)
|
||||
from ray.rllib.algorithms.sac.rnnsac_torch_policy import RNNSACTorchPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
|
||||
|
||||
DEFAULT_CONFIG = SACTrainer.merge_trainer_configs(
|
||||
SAC_DEFAULT_CONFIG,
|
||||
{
|
||||
# Batch mode (see common config)
|
||||
"batch_mode": "complete_episodes",
|
||||
# If True, assume a zero-initialized state input (no matter where in
|
||||
# the episode the sequence is located).
|
||||
# If False, store the initial states along with each SampleBatch, use
|
||||
# it (as initial state when running through the network for training),
|
||||
# and update that initial state during training (from the internal
|
||||
# state outputs of the immediately preceding sequence).
|
||||
"zero_init_states": True,
|
||||
"replay_buffer_config": {
|
||||
# If > 0, use the `burn_in` first steps of each replay-sampled sequence
|
||||
# (starting either from all 0.0-values if `zero_init_state=True` or
|
||||
# from the already stored values) to calculate an even more accurate
|
||||
# initial states for the actual sequence (starting after this burn-in
|
||||
# window). In the burn-in case, the actual length of the sequence
|
||||
# used for loss calculation is `n - burn_in` time steps
|
||||
# (n=LSTM’s/attention net’s max_seq_len).
|
||||
"replay_burn_in": 0,
|
||||
# Set automatically: The number of contiguous environment steps to
|
||||
# replay at once. Will be calculated via
|
||||
# model->max_seq_len + burn_in.
|
||||
# Do not set this to any valid value!
|
||||
"replay_sequence_length": -1,
|
||||
},
|
||||
"burn_in": DEPRECATED_VALUE,
|
||||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
|
||||
class RNNSACConfig(SACConfig):
|
||||
"""Defines a configuration class from which an RNNSACTrainer can be built.
|
||||
|
||||
Example:
|
||||
>>> config = RNNSACConfig().training(gamma=0.9, lr=0.01)\
|
||||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
super().__init__(trainer_class=trainer_class or RNNSACTrainer)
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
self.burn_in = DEPRECATED_VALUE
|
||||
self.batch_mode = "complete_episodes"
|
||||
self.zero_init_states = True
|
||||
self.replay_buffer_config["replay_burn_in"] = 0
|
||||
# Set automatically: The number of contiguous environment steps to
|
||||
# replay at once. Will be calculated via
|
||||
# model->max_seq_len + burn_in.
|
||||
# Do not set this to any valid value!
|
||||
self.replay_buffer_config["replay_sequence_length"] = -1
|
||||
|
||||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
||||
@override(SACConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
zero_init_states: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> "RNNSACConfig":
|
||||
"""Sets the training related configuration.
|
||||
|
||||
Args:
|
||||
zero_init_states: If True, assume a zero-initialized state input (no matter
|
||||
where in the episode the sequence is located).
|
||||
If False, store the initial states along with each SampleBatch, use
|
||||
it (as initial state when running through the network for training),
|
||||
and update that initial state during training (from the internal
|
||||
state outputs of the immediately preceding sequence).
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
"""
|
||||
super().training(**kwargs)
|
||||
if zero_init_states is not None:
|
||||
self.zero_init_states = zero_init_states
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class RNNSACTrainer(SACTrainer):
|
||||
@classmethod
|
||||
@override(SACTrainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
return RNNSACConfig().to_dict()
|
||||
|
||||
@override(SACTrainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
|
@ -81,3 +109,19 @@ class RNNSACTrainer(SACTrainer):
|
|||
@override(SACTrainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
return RNNSACTorchPolicy
|
||||
|
||||
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(RNNSACConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.algorithms.sac.rnnsac.DEFAULT_CONFIG",
|
||||
new="ray.rllib.algorithms.sac.rnnsac.RNNSACConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _deprecated_default_config()
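A minimal usage sketch of the new RNNSACConfig (buffer settings are illustrative; RNNSAC is torch-only):

    from ray.rllib.algorithms.sac import RNNSACConfig

    config = (
        RNNSACConfig()
        .framework("torch")
        .training(
            zero_init_states=True,
            replay_buffer_config={
                "type": "MultiAgentPrioritizedReplayBuffer",
                "replay_burn_in": 20,
            },
        )
        .rollouts(num_rollout_workers=0)
    )
    trainer = config.build(env="CartPole-v0")
    print(trainer.train())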
|
||||
|
|
|
@ -45,12 +45,12 @@ def build_rnnsac_model(
|
|||
num_outputs = int(np.product(obs_space.shape))
|
||||
|
||||
# Force-ignore any additionally provided hidden layer sizes.
|
||||
# Everything should be configured using SAC's "Q_model" and "policy_model"
|
||||
# settings.
|
||||
# Everything should be configured using SAC's `q_model_config` and
|
||||
# `policy_model_config` config settings.
|
||||
policy_model_config = MODEL_DEFAULTS.copy()
|
||||
policy_model_config.update(config["policy_model"])
|
||||
policy_model_config.update(config["policy_model_config"])
|
||||
q_model_config = MODEL_DEFAULTS.copy()
|
||||
q_model_config.update(config["Q_model"])
|
||||
q_model_config.update(config["q_model_config"])
|
||||
|
||||
default_model_cls = RNNSACTorchModel
|
||||
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
import logging
|
||||
from typing import Type
|
||||
from typing import Type, Dict, Any, Optional, Union
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.algorithms.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.utils.deprecation import (
|
||||
DEPRECATED_VALUE,
|
||||
deprecation_warning,
|
||||
Deprecated,
|
||||
)
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
|
@ -15,160 +19,255 @@ tfp = try_import_tfp()
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the (base) `Trainer` config in
|
||||
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# === Model ===
|
||||
# Use two Q-networks (instead of one) for action-value estimation.
|
||||
# Note: Each Q-network will have its own target network.
|
||||
"twin_q": True,
|
||||
# Use a e.g. conv2D state preprocessing network before concatenating the
|
||||
# resulting (feature) vector with the action input for the input to
|
||||
# the Q-networks.
|
||||
"use_state_preprocessor": DEPRECATED_VALUE,
|
||||
# Model options for the Q network(s). These will override MODEL_DEFAULTS.
|
||||
# The `Q_model` dict is treated just as the top-level `model` dict in
|
||||
# setting up the Q-network(s) (2 if twin_q=True).
|
||||
# That means, you can do for different observation spaces:
|
||||
# obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
|
||||
# obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
|
||||
# -> post_fcnet
|
||||
# obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
|
||||
# -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
|
||||
# You can also have SAC use your custom_model as Q-model(s), by simply
|
||||
# specifying the `custom_model` sub-key in below dict (just like you would
|
||||
# do in the top-level `model` dict.
|
||||
"Q_model": {
|
||||
"fcnet_hiddens": [256, 256],
|
||||
"fcnet_activation": "relu",
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": None,
|
||||
"custom_model": None, # Use this to define custom Q-model(s).
|
||||
"custom_model_config": {},
|
||||
},
|
||||
# Model options for the policy function (see `Q_model` above for details).
|
||||
# The difference to `Q_model` above is that no action concat'ing is
|
||||
# performed before the post_fcnet stack.
|
||||
"policy_model": {
|
||||
"fcnet_hiddens": [256, 256],
|
||||
"fcnet_activation": "relu",
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": None,
|
||||
"custom_model": None, # Use this to define a custom policy model.
|
||||
"custom_model_config": {},
|
||||
},
|
||||
# Actions are already normalized, no need to clip them further.
|
||||
"clip_actions": False,
|
||||
class SACConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which an SACTrainer can be built.
|
||||
|
||||
# === Learning ===
|
||||
# Update the target by \tau * policy + (1-\tau) * target_policy.
|
||||
"tau": 5e-3,
|
||||
# Initial value to use for the entropy weight alpha.
|
||||
"initial_alpha": 1.0,
|
||||
# Target entropy lower bound. If "auto", will be set to -|A| (e.g. -2.0 for
|
||||
# Discrete(2), -3.0 for Box(shape=(3,))).
|
||||
# This is the inverse of reward scale, and will be optimized automatically.
|
||||
"target_entropy": "auto",
|
||||
# N-step target updates. If >1, sars' tuples in trajectories will be
|
||||
# postprocessed to become sa[discounted sum of R][s t+n] tuples.
|
||||
"n_step": 1,
|
||||
# Minimum env sampling timesteps to accumulate within a single `train()` call. This
|
||||
# value does not affect learning, only the number of times `Trainer.step_attempt()`
|
||||
# is called by `Trainer.train()`. If - after one `step_attempt()`, the env sampling
|
||||
# timestep count has not been reached, will perform n more `step_attempt()` calls
|
||||
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
|
||||
"min_sample_timesteps_per_reporting": 100,
|
||||
Example:
|
||||
>>> config = SACConfig().training(gamma=0.9, lr=0.01)\
|
||||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
# === Replay buffer ===
|
||||
"replay_buffer_config": {
|
||||
"type": "MultiAgentReplayBuffer",
|
||||
# Specify prioritized replay by supplying a buffer type that supports
|
||||
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
|
||||
"prioritized_replay": DEPRECATED_VALUE,
|
||||
"capacity": int(1e6),
|
||||
# How many steps of the model to sample before learning starts.
|
||||
"learning_starts": 1500,
|
||||
# The number of continuous environment steps to replay at once. This may
|
||||
# be set to greater than 1 to support recurrent models.
|
||||
"replay_sequence_length": 1,
|
||||
"prioritized_replay_alpha": 0.6,
|
||||
# Beta parameter for sampling from prioritized replay buffer.
|
||||
"prioritized_replay_beta": 0.4,
|
||||
# Epsilon to add to the TD errors when updating priorities.
|
||||
"prioritized_replay_eps": 1e-6,
|
||||
# Whether to compute priorities on workers.
|
||||
"worker_side_prioritization": False,
|
||||
},
|
||||
# Set this to True, if you want the contents of your buffer(s) to be
|
||||
# stored in any saved checkpoints as well.
|
||||
# Warnings will be created if:
|
||||
# - This is True AND restoring from a checkpoint that contains no buffer
|
||||
# data.
|
||||
# - This is False AND restoring from a checkpoint that does contain
|
||||
# buffer data.
|
||||
"store_buffer_in_checkpoints": False,
|
||||
# Whether to LZ4 compress observations
|
||||
"compress_observations": False,
|
||||
def __init__(self, trainer_class=None):
|
||||
super().__init__(trainer_class=trainer_class or SACTrainer)
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# SAC-specific config settings.
|
||||
self.twin_q = True
|
||||
self.q_model_config = {
|
||||
"fcnet_hiddens": [256, 256],
|
||||
"fcnet_activation": "relu",
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": None,
|
||||
"custom_model": None, # Use this to define custom Q-model(s).
|
||||
"custom_model_config": {},
|
||||
}
|
||||
self.policy_model_config = {
|
||||
"fcnet_hiddens": [256, 256],
|
||||
"fcnet_activation": "relu",
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": None,
|
||||
"custom_model": None, # Use this to define a custom policy model.
|
||||
"custom_model_config": {},
|
||||
}
|
||||
self.clip_actions = False
|
||||
self.tau = 5e-3
|
||||
self.initial_alpha = 1.0
|
||||
self.target_entropy = "auto"
|
||||
self.n_step = 1
|
||||
self.replay_buffer_config = {
|
||||
"_enable_replay_buffer_api": True,
|
||||
"type": "MultiAgentPrioritizedReplayBuffer",
|
||||
"capacity": int(1e6),
|
||||
# How many steps of the model to sample before learning starts.
|
||||
"learning_starts": 1500,
|
||||
# If True prioritized replay buffer will be used.
|
||||
"prioritized_replay": False,
|
||||
"prioritized_replay_alpha": 0.6,
|
||||
"prioritized_replay_beta": 0.4,
|
||||
"prioritized_replay_eps": 1e-6,
|
||||
# Whether to compute priorities already on the remote worker side.
|
||||
"worker_side_prioritization": False,
|
||||
}
|
||||
self.store_buffer_in_checkpoints = False
|
||||
self.training_intensity = None
|
||||
self.optimization = {
|
||||
"actor_learning_rate": 3e-4,
|
||||
"critic_learning_rate": 3e-4,
|
||||
"entropy_learning_rate": 3e-4,
|
||||
}
|
||||
self.grad_clip = None
|
||||
self.target_network_update_freq = 0
|
||||
|
||||
# The intensity with which to update the model (vs collecting samples from
|
||||
# the env). If None, uses the "natural" value of:
|
||||
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
|
||||
# `num_envs_per_worker`).
|
||||
# If provided, will make sure that the ratio between ts inserted into and
|
||||
# sampled from the buffer matches the given value.
|
||||
# Example:
|
||||
# training_intensity=1000.0
|
||||
# train_batch_size=250 rollout_fragment_length=1
|
||||
# num_workers=1 (or 0) num_envs_per_worker=1
|
||||
# -> natural value = 250 / 1 = 250.0
|
||||
# -> will make sure that replay+train op will be executed 4x as
|
||||
# often as rollout+insert op (4 * 250 = 1000).
|
||||
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
|
||||
"training_intensity": None,
|
||||
# .rollout()
|
||||
self.rollout_fragment_length = 1
|
||||
self.compress_observations = False
|
||||
|
||||
# === Optimization ===
|
||||
"optimization": {
|
||||
"actor_learning_rate": 3e-4,
|
||||
"critic_learning_rate": 3e-4,
|
||||
"entropy_learning_rate": 3e-4,
|
||||
},
|
||||
# If not None, clip gradients during optimization at this value.
|
||||
"grad_clip": None,
|
||||
# Update the replay buffer with this many samples at once. Note that this
|
||||
# setting applies per-worker if num_workers > 1.
|
||||
"rollout_fragment_length": 1,
|
||||
# Size of a batched sampled from replay buffer for training.
|
||||
"train_batch_size": 256,
|
||||
# Update the target network every `target_network_update_freq` sample steps.
|
||||
"target_network_update_freq": 0,
|
||||
# .training()
|
||||
self.train_batch_size = 256
|
||||
|
||||
# === Parallelism ===
|
||||
# Whether to use a GPU for local optimization.
|
||||
"num_gpus": 0,
|
||||
# Number of workers for collecting samples with. This only makes sense
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you"re using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Whether to allocate GPUs for workers (if > 0).
|
||||
"num_gpus_per_worker": 0,
|
||||
# Whether to allocate CPUs for workers (if > 0).
|
||||
"num_cpus_per_worker": 1,
|
||||
# Prevent reporting frequency from going lower than this time span.
|
||||
"min_time_s_per_reporting": 1,
|
||||
# .reporting()
|
||||
self.min_time_s_per_reporting = 1
|
||||
self.min_sample_timesteps_per_reporting = 100
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
# Whether the loss should be calculated deterministically (w/o the
|
||||
# stochastic action sampling step). True only useful for cont. actions and
|
||||
# for debugging!
|
||||
"_deterministic_loss": False,
|
||||
# Use a Beta-distribution instead of a SquashedGaussian for bounded,
|
||||
# continuous action spaces (not recommended, for debugging only).
|
||||
"_use_beta_distribution": False,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
self._deterministic_loss = False
|
||||
self._use_beta_distribution = False
|
||||
|
||||
self.use_state_preprocessor = DEPRECATED_VALUE
|
||||
self.worker_side_prioritization = DEPRECATED_VALUE
|
||||
|
||||
@override(TrainerConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
twin_q: Optional[bool] = None,
|
||||
q_model_config: Optional[Dict[str, Any]] = None,
|
||||
policy_model_config: Optional[Dict[str, Any]] = None,
|
||||
tau: Optional[float] = None,
|
||||
initial_alpha: Optional[float] = None,
|
||||
target_entropy: Optional[Union[str, float]] = None,
|
||||
n_step: Optional[int] = None,
|
||||
store_buffer_in_checkpoints: Optional[bool] = None,
|
||||
replay_buffer_config: Optional[Dict[str, Any]] = None,
|
||||
training_intensity: Optional[float] = None,
|
||||
clip_actions: Optional[bool] = None,
|
||||
grad_clip: Optional[float] = None,
|
||||
optimization_config: Optional[Dict[str, Any]] = None,
|
||||
target_network_update_freq: Optional[int] = None,
|
||||
_deterministic_loss: Optional[bool] = None,
|
||||
_use_beta_distribution: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> "SACConfig":
|
||||
"""Sets the training related configuration.
|
||||
|
||||
Args:
|
||||
twin_q: Use two Q-networks (instead of one) for action-value estimation.
|
||||
Note: Each Q-network will have its own target network.
|
||||
q_model_config: Model configs for the Q network(s). These will override
|
||||
MODEL_DEFAULTS. This is treated just as the top-level `model` dict in
|
||||
setting up the Q-network(s) (2 if twin_q=True).
|
||||
That means, you can do for different observation spaces:
|
||||
obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
|
||||
obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
|
||||
-> post_fcnet
|
||||
obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
|
||||
-> vision-net -> concat w/ Box(1D) and action -> post_fcnet
|
||||
You can also have SAC use your custom_model as Q-model(s), by simply
|
||||
specifying the `custom_model` sub-key in the dict below (just like you would
|
||||
do in the top-level `model` dict).
|
||||
policy_model_config: Model options for the policy function (see
|
||||
`q_model_config` above for details). The difference to `q_model_config`
|
||||
above is that no action concat'ing is performed before the post_fcnet
|
||||
stack.
|
||||
tau: Update the target by \tau * policy + (1-\tau) * target_policy.
|
||||
initial_alpha: Initial value to use for the entropy weight alpha.
|
||||
target_entropy: Target entropy lower bound. If "auto", will be set
|
||||
to -|A| (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))).
|
||||
This is the inverse of reward scale, and will be optimized
|
||||
automatically.
|
||||
n_step: N-step target updates. If >1, sars' tuples in trajectories will be
|
||||
postprocessed to become sa[discounted sum of R][s t+n] tuples.
|
||||
store_buffer_in_checkpoints: Set this to True, if you want the contents of
|
||||
your buffer(s) to be stored in any saved checkpoints as well.
|
||||
Warnings will be created if:
|
||||
- This is True AND restoring from a checkpoint that contains no buffer
|
||||
data.
|
||||
- This is False AND restoring from a checkpoint that does contain
|
||||
buffer data.
|
||||
replay_buffer_config: Replay buffer config.
|
||||
Examples:
|
||||
{
|
||||
"_enable_replay_buffer_api": True,
|
||||
"type": "MultiAgentReplayBuffer",
|
||||
"learning_starts": 1000,
|
||||
"capacity": 50000,
|
||||
"replay_batch_size": 32,
|
||||
"replay_sequence_length": 1,
|
||||
}
|
||||
- OR -
|
||||
{
|
||||
"_enable_replay_buffer_api": True,
|
||||
"type": "MultiAgentPrioritizedReplayBuffer",
|
||||
"capacity": 50000,
|
||||
"prioritized_replay_alpha": 0.6,
|
||||
"prioritized_replay_beta": 0.4,
|
||||
"prioritized_replay_eps": 1e-6,
|
||||
"replay_sequence_length": 1,
|
||||
}
|
||||
- Where -
|
||||
prioritized_replay_alpha: Alpha parameter controls the degree of
|
||||
prioritization in the buffer. In other words, when a buffer sample has
|
||||
a higher temporal-difference error, with how much more probability
|
||||
should it be drawn to update the parametrized Q-network. 0.0
|
||||
corresponds to uniform probability. Setting this much above 1.0 may quickly
|
||||
result in the sampling distribution becoming heavily “pointy” with
|
||||
low entropy.
|
||||
prioritized_replay_beta: Beta parameter controls the degree of
|
||||
importance sampling which suppresses the influence of gradient updates
|
||||
from samples that have higher probability of being sampled via alpha
|
||||
parameter and the temporal-difference error.
|
||||
prioritized_replay_eps: Epsilon parameter sets the baseline probability
|
||||
for sampling so that when the temporal-difference error of a sample is
|
||||
zero, there is still a chance of drawing the sample.
|
||||
training_intensity: The intensity with which to update the model (vs
|
||||
collecting samples from the env).
|
||||
If None, uses "natural" values of:
|
||||
`train_batch_size` / (`rollout_fragment_length` x `num_workers` x
|
||||
`num_envs_per_worker`).
|
||||
If not None, will make sure that the ratio between timesteps inserted
|
||||
into and sampled from the buffer matches the given value.
|
||||
Example:
|
||||
training_intensity=1000.0
|
||||
train_batch_size=250
|
||||
rollout_fragment_length=1
|
||||
num_workers=1 (or 0)
|
||||
num_envs_per_worker=1
|
||||
-> natural value = 250 / 1 = 250.0
|
||||
-> will make sure that replay+train op will be executed 4x as often as
|
||||
rollout+insert op (4 * 250 = 1000).
|
||||
See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
|
||||
clip_actions: Whether to clip actions. If actions are already normalized,
|
||||
this should be set to False.
|
||||
grad_clip: If not None, clip gradients during optimization at this value.
|
||||
optimization_config: Config dict for optimization. Set the supported keys
|
||||
`actor_learning_rate`, `critic_learning_rate`, and
|
||||
`entropy_learning_rate` in here.
|
||||
target_network_update_freq: Update the target network every
|
||||
`target_network_update_freq` steps.
|
||||
_deterministic_loss: Whether the loss should be calculated deterministically
|
||||
(w/o the stochastic action sampling step). Only useful for
|
||||
continuous actions and for debugging.
|
||||
_use_beta_distribution: Use a Beta-distribution instead of a
|
||||
`SquashedGaussian` for bounded, continuous action spaces (not
|
||||
recommended; for debugging only).
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
||||
if twin_q is not None:
|
||||
self.twin_q = twin_q
|
||||
if q_model_config is not None:
|
||||
self.q_model_config = q_model_config
|
||||
if policy_model_config is not None:
|
||||
self.policy_model_config = policy_model_config
|
||||
if tau is not None:
|
||||
self.tau = tau
|
||||
if initial_alpha is not None:
|
||||
self.initial_alpha = initial_alpha
|
||||
if target_entropy is not None:
|
||||
self.target_entropy = target_entropy
|
||||
if n_step is not None:
|
||||
self.n_step = n_step
|
||||
if store_buffer_in_checkpoints is not None:
|
||||
self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
|
||||
if replay_buffer_config is not None:
|
||||
self.replay_buffer_config = replay_buffer_config
|
||||
if training_intensity is not None:
|
||||
self.training_intensity = training_intensity
|
||||
if clip_actions is not None:
|
||||
self.clip_actions = clip_actions
|
||||
if grad_clip is not None:
|
||||
self.grad_clip = grad_clip
|
||||
if optimization_config is not None:
|
||||
self.optimization_config = optimization_config
|
||||
if target_network_update_freq is not None:
|
||||
self.target_network_update_freq = target_network_update_freq
|
||||
if _deterministic_loss is not None:
|
||||
self._deterministic_loss = _deterministic_loss
|
||||
if _use_beta_distribution is not None:
|
||||
self._use_beta_distribution = _use_beta_distribution
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class SACTrainer(DQNTrainer):
|
||||
|
@ -183,13 +282,13 @@ class SACTrainer(DQNTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._allow_unknown_subkeys += ["policy_model", "Q_model"]
|
||||
self._allow_unknown_subkeys += ["policy_model_config", "q_model_config"]
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@override(DQNTrainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
return SACConfig().to_dict()
|
||||
|
||||
@override(DQNTrainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
|
@ -200,6 +299,22 @@ class SACTrainer(DQNTrainer):
|
|||
deprecation_warning(old="config['use_state_preprocessor']", error=False)
|
||||
config["use_state_preprocessor"] = DEPRECATED_VALUE
|
||||
|
||||
if config.get("policy_model", DEPRECATED_VALUE) != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['policy_model']",
|
||||
new="config['policy_model_config']",
|
||||
error=False,
|
||||
)
|
||||
config["policy_model_config"] = config["policy_model"]
|
||||
|
||||
if config.get("Q_model", DEPRECATED_VALUE) != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['Q_model']",
|
||||
new="config['q_model_config']",
|
||||
error=False,
|
||||
)
|
||||
config["q_model_config"] = config["Q_model"]
|
||||
|
||||
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
|
||||
raise ValueError("`grad_clip` value must be > 0.0!")
|
||||
|
||||
|
@ -220,3 +335,20 @@ class SACTrainer(DQNTrainer):
|
|||
return SACTorchPolicy
|
||||
else:
|
||||
return SACTFPolicy
|
||||
|
||||
|
||||
# Deprecated: Use ray.rllib.algorithms.sac.SACConfig instead!
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(SACConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.algorithms.sac.sac.DEFAULT_CONFIG",
|
||||
new="ray.rllib.algorithms.sac.sac.SACConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _deprecated_default_config()
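A minimal usage sketch of the new SACConfig builder API (hyperparameter values are illustrative):

    from ray.rllib.algorithms.sac import SACConfig

    config = (
        SACConfig()
        .training(
            twin_q=True,
            n_step=1,
            train_batch_size=256,
            replay_buffer_config={"capacity": 50000, "learning_starts": 1000},
        )
        .rollouts(num_rollout_workers=0)
    )
    trainer = config.build(env="Pendulum-v1")
    print(trainer.train())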
|
||||
|
|
|
@ -19,9 +19,9 @@ class SACTFModel(TFModelV2):
|
|||
|
||||
To customize, do one of the following:
|
||||
- sub-class SACTFModel and override one or more of its methods.
|
||||
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
|
||||
- Use SAC's `q_model_config` and `policy_model_config` keys to tweak the default model
|
||||
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
|
||||
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
|
||||
- Use SAC's `q_model_config->custom_model` and `policy_model_config->custom_model` keys
|
||||
to specify your own custom Q-model(s) and policy-models, which will be
|
||||
created within this SACTFModel (see `build_policy_model` and
|
||||
`build_q_model`).
|
||||
|
@ -160,7 +160,7 @@ class SACTFModel(TFModelV2):
|
|||
|
||||
Override this method in a sub-class of SACTFModel to implement your
|
||||
own Q-nets. Alternatively, simply set `custom_model` within the
|
||||
top level SAC `Q_model` config key to make this default implementation
|
||||
top level SAC `q_model_config` config key to make this default implementation
|
||||
of `build_q_model` use your custom Q-nets.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -72,12 +72,12 @@ def build_sac_model(
|
|||
`policy.target_model`.
|
||||
"""
|
||||
# Force-ignore any additionally provided hidden layer sizes.
|
||||
# Everything should be configured using SAC's "Q_model" and "policy_model"
|
||||
# settings.
|
||||
# Everything should be configured using SAC's `q_model_config` and
|
||||
# `policy_model_config` config settings.
|
||||
policy_model_config = copy.deepcopy(MODEL_DEFAULTS)
|
||||
policy_model_config.update(config["policy_model"])
|
||||
policy_model_config.update(config["policy_model_config"])
|
||||
q_model_config = copy.deepcopy(MODEL_DEFAULTS)
|
||||
q_model_config.update(config["Q_model"])
|
||||
q_model_config.update(config["q_model_config"])
|
||||
|
||||
default_model_cls = SACTorchModel if config["framework"] == "torch" else SACTFModel
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
|||
|
||||
To customize, do one of the following:
|
||||
- sub-class SACTorchModel and override one or more of its methods.
|
||||
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
|
||||
- Use SAC's `q_model_config` and `policy_model_config` keys to tweak the default model
|
||||
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
|
||||
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
|
||||
- Use SAC's `q_model_config->custom_model` and `policy_model_config->custom_model` keys
|
||||
to specify your own custom Q-model(s) and policy-models, which will be
|
||||
created within this SACTorchModel (see `build_policy_model` and
|
||||
`build_q_model`).
|
||||
|
@ -168,7 +168,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
|||
|
||||
Override this method in a sub-class of SACTorchModel to implement your
|
||||
own Q-nets. Alternatively, simply set `custom_model` within the
|
||||
top level SAC `Q_model` config key to make this default implementation
|
||||
top level SAC `q_model_config` config key to make this default implementation
|
||||
of `build_q_model` use your custom Q-nets.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.algorithms.sac as sac
|
||||
from ray.rllib.algorithms import sac
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import check_compute_single_action, framework_iterator
|
||||
|
||||
|
@ -20,42 +20,39 @@ class TestRNNSAC(unittest.TestCase):
|
|||
|
||||
def test_rnnsac_compilation(self):
|
||||
"""Test whether a R2D2Trainer can be built on all frameworks."""
|
||||
config = sac.RNNSAC_DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
|
||||
# Wrap with an LSTM and use a very simple base-model.
|
||||
config["model"] = {
|
||||
"max_seq_len": 20,
|
||||
}
|
||||
config["policy_model"] = {
|
||||
"use_lstm": True,
|
||||
"lstm_cell_size": 64,
|
||||
"fcnet_hiddens": [10],
|
||||
"lstm_use_prev_action": True,
|
||||
"lstm_use_prev_reward": True,
|
||||
}
|
||||
config["Q_model"] = {
|
||||
"use_lstm": True,
|
||||
"lstm_cell_size": 64,
|
||||
"fcnet_hiddens": [10],
|
||||
"lstm_use_prev_action": True,
|
||||
"lstm_use_prev_reward": True,
|
||||
}
|
||||
|
||||
# Test with MultiAgentPrioritizedReplayBuffer
|
||||
config["replay_buffer_config"] = {
|
||||
"type": "MultiAgentPrioritizedReplayBuffer",
|
||||
"replay_burn_in": 20,
|
||||
"zero_init_states": True,
|
||||
}
|
||||
|
||||
config["lr"] = 5e-4
|
||||
|
||||
config = (
|
||||
sac.RNNSACConfig()
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.training(
|
||||
# Wrap with an LSTM and use a very simple base-model.
|
||||
model={"max_seq_len": 20},
|
||||
policy_model_config={
|
||||
"use_lstm": True,
|
||||
"lstm_cell_size": 64,
|
||||
"fcnet_hiddens": [10],
|
||||
"lstm_use_prev_action": True,
|
||||
"lstm_use_prev_reward": True,
|
||||
},
|
||||
q_model_config={
|
||||
"use_lstm": True,
|
||||
"lstm_cell_size": 64,
|
||||
"fcnet_hiddens": [10],
|
||||
"lstm_use_prev_action": True,
|
||||
"lstm_use_prev_reward": True,
|
||||
},
|
||||
replay_buffer_config={
|
||||
"type": "MultiAgentPrioritizedReplayBuffer",
|
||||
"replay_burn_in": 20,
|
||||
"zero_init_states": True,
|
||||
},
|
||||
lr=5e-4,
|
||||
)
|
||||
)
|
||||
num_iterations = 1
|
||||
|
||||
# Test building an RNNSAC agent in all frameworks.
|
||||
for _ in framework_iterator(config, frameworks="torch"):
|
||||
trainer = sac.RNNSACTrainer(config=config, env="CartPole-v0")
|
||||
trainer = config.build(env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print(results)
|
||||
|
|
|
@ -6,7 +6,7 @@ import re
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.algorithms.sac as sac
|
||||
from ray.rllib.algorithms import sac
|
||||
from ray.rllib.algorithms.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
|
||||
from ray.rllib.algorithms.sac.sac_torch_policy import actor_critic_loss as loss_torch
|
||||
from ray.rllib.examples.env.random_env import RandomEnv
|
||||
|
@ -74,20 +74,17 @@ class TestSAC(unittest.TestCase):
|
|||
|
||||
def test_sac_compilation(self):
|
||||
"""Tests whether an SACTrainer can be built with all frameworks."""
|
||||
config = sac.DEFAULT_CONFIG.copy()
|
||||
config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
config["n_step"] = 3
|
||||
config["twin_q"] = True
|
||||
config["replay_buffer_config"]["learning_starts"] = 0
|
||||
config["rollout_fragment_length"] = 10
|
||||
config["train_batch_size"] = 10
|
||||
# If we use default buffer size (1e6), the buffer will take up
|
||||
# 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
|
||||
# available system memory (8.34816 GB).
|
||||
config["replay_buffer_config"]["capacity"] = 40000
|
||||
# Test with saved replay buffer.
|
||||
config["store_buffer_in_checkpoints"] = True
|
||||
config = (
|
||||
sac.SACConfig()
|
||||
.training(
|
||||
n_step=3,
|
||||
twin_q=True,
|
||||
replay_buffer_config={"learning_starts": 0, "capacity": 40000},
|
||||
store_buffer_in_checkpoints=True,
|
||||
train_batch_size=10,
|
||||
)
|
||||
.rollouts(num_rollout_workers=0, rollout_fragment_length=10)
|
||||
)
|
||||
num_iterations = 1
|
||||
|
||||
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
|
||||
|
@ -134,12 +131,12 @@ class TestSAC(unittest.TestCase):
|
|||
print("Env={}".format(env))
|
||||
# Test making the Q-model a custom one for CartPole, otherwise,
|
||||
# use the default model.
|
||||
config["Q_model"]["custom_model"] = (
|
||||
config.q_model_config["custom_model"] = (
|
||||
"batch_norm{}".format("_torch" if fw == "torch" else "")
|
||||
if env == "CartPole-v0"
|
||||
else None
|
||||
)
|
||||
trainer = sac.SACTrainer(config=config, env=env)
|
||||
trainer = config.build(env=env)
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
check_train_results(results)
|
||||
|
@@ -167,22 +164,25 @@ class TestSAC(unittest.TestCase):

def test_sac_loss_function(self):
"""Tests SAC loss function results across all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
# Run locally.
config["seed"] = 42
config["num_workers"] = 0
config["replay_buffer_config"]["learning_starts"] = 0
config["twin_q"] = False
config["gamma"] = 0.99
# Switch on deterministic loss so we can compare the loss values.
config["_deterministic_loss"] = True
# Use very simple nets.
config["Q_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_hiddens"] = [10]
# Make sure, timing differences do not affect trainer.train().
config["min_time_s_per_reporting"] = 0
# Test SAC with Simplex action space.
config["env_config"] = {"simplex_actions": True}
config = (
sac.SACConfig()
.training(
twin_q=False,
gamma=0.99,
_deterministic_loss=True,
q_model_config={"fcnet_hiddens": [10]},
policy_model_config={"fcnet_hiddens": [10]},
replay_buffer_config={"learning_starts": 0},
)
.rollouts(num_rollout_workers=0)
.reporting(
min_time_s_per_reporting=0,
)
.environment(
env_config={"simplex_actions": True},
)
.debugging(seed=42)
)

map_ = {
# Action net.

@@ -247,7 +247,7 @@ class TestSAC(unittest.TestCase):
config, frameworks=("tf", "torch"), session=True
):
# Generate Trainer and get its default Policy object.
trainer = sac.SACTrainer(config=config, env=env)
trainer = config.build(env=env)
policy = trainer.get_policy()
p_sess = None
if sess:

@@ -287,7 +287,7 @@ class TestSAC(unittest.TestCase):
sorted(weights_dict.keys()),
log_alpha,
fw,
gamma=config["gamma"],
gamma=config.gamma,
sess=sess,
)

@@ -520,19 +520,22 @@ class TestSAC(unittest.TestCase):
return dict_samples[self.steps], 1, self.steps >= 5, {}

tune.register_env("nested", lambda _: NestedDictEnv())

config = sac.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["replay_buffer_config"]["learning_starts"] = 0
config["rollout_fragment_length"] = 5
config["train_batch_size"] = 5
config["replay_buffer_config"]["capacity"] = 10
# Disable preprocessors.
config["_disable_preprocessor_api"] = True
config = (
sac.SACConfig()
.training(
replay_buffer_config={"learning_starts": 0, "capacity": 10},
train_batch_size=5,
)
.rollouts(
num_rollout_workers=0,
rollout_fragment_length=5,
)
.experimental(_disable_preprocessor_api=True)
)
num_iterations = 1

for _ in framework_iterator(config, with_eager_tracing=True):
trainer = sac.SACTrainer(env="nested", config=config)
trainer = config.build(env="nested")
for _ in range(num_iterations):
results = trainer.train()
check_train_results(results)

@@ -139,7 +139,6 @@ class SlateQConfig(TrainerConfig):
rmsprop_epsilon: Optional[float] = None,
grad_clip: Optional[float] = None,
n_step: Optional[int] = None,
worker_side_prioritization: Optional[bool] = None,
**kwargs,
) -> "SlateQConfig":
"""Sets the training related configuration.

@@ -172,8 +171,6 @@ class SlateQConfig(TrainerConfig):
rmsprop_epsilon: RMSProp epsilon hyperparameter.
grad_clip: If not None, clip gradients during optimization at this value.
n_step: N-step parameter for Q-learning.
worker_side_prioritization: Whether to compute priorities for the replay
buffer on the workers.

Returns:
This updated TrainerConfig object.

@@ -205,8 +202,6 @@ class SlateQConfig(TrainerConfig):
self.grad_clip = grad_clip
if n_step is not None:
self.n_step = n_step
if worker_side_prioritization is not None:
self.worker_side_prioritization = worker_side_prioritization

return self

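(Illustrative sketch, not part of the diff.) With the worker_side_prioritization argument removed from SlateQConfig.training(), the flag is expected to live inside the replay buffer config instead, as the deprecation notice added later in this diff indicates. Assuming the config object exposes a replay_buffer_config dict like the other new TrainerConfig classes, and with the import path assumed (it is not shown in this hunk):

from ray.rllib.agents.slateq import SlateQConfig  # path assumed; may differ

config = SlateQConfig()
# The removed training(worker_side_prioritization=...) argument becomes:
config.replay_buffer_config["worker_side_prioritization"] = True
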
@@ -29,7 +29,7 @@ if __name__ == "__main__":

# See rllib/tuned_examples/cql/pendulum-cql.yaml for comparison.

config = cql.CQL_DEFAULT_CONFIG.copy()
config = cql.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["horizon"] = 200
config["soft_horizon"] = True

@@ -45,11 +45,11 @@ if __name__ == "__main__":
config["replay_buffer_config"]["capacity"] = int(1e6)
config["tau"] = 0.005
config["target_entropy"] = "auto"
config["Q_model"] = {
config["q_model_config"] = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
}
config["policy_model"] = {
config["policy_model_config"] = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
}

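(Illustrative sketch, not part of the diff.) The CQL example above keeps the legacy dict API and only renames the model keys. Since the commit title also introduces a CQL TrainerConfig object, the same model settings could presumably be written builder-style; the CQLConfig class name, its import path, and its exact training() signature are assumptions here:

from ray.rllib.algorithms import cql  # import path assumed, mirroring the sac import above

config = (
    cql.CQLConfig()  # assumed class, analogous to SACConfig
    .training(
        tau=0.005,
        q_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
        policy_model_config={"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
    )
    .rollouts(num_rollout_workers=0)
)
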
@@ -55,14 +55,14 @@ config = {
"model": {
"max_seq_len": 20,
},
"policy_model": {
"policy_model_config": {
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [64, 64],
"lstm_use_prev_action": True,
"lstm_use_prev_reward": True,
},
"Q_model": {
"q_model_config": {
"use_lstm": True,
"lstm_cell_size": 64,
"fcnet_hiddens": [64, 64],

@@ -177,8 +177,8 @@ class TestComputeLogLikelihood(unittest.TestCase):
"""Tests SAC's (cont. actions) compute_log_likelihoods method."""
config = sac.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["policy_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_activation"] = "linear"
config["policy_model_config"]["fcnet_hiddens"] = [10]
config["policy_model_config"]["fcnet_activation"] = "linear"
prev_a = np.array([0.0])

# SAC cont uses a squashed normal distribution. Implement it's logp

@@ -210,8 +210,8 @@ class TestComputeLogLikelihood(unittest.TestCase):
"""Tests SAC's (discrete actions) compute_log_likelihoods method."""
config = sac.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["policy_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_activation"] = "linear"
config["policy_model_config"]["fcnet_hiddens"] = [10]
config["policy_model_config"]["fcnet_activation"] = "linear"
prev_a = np.array(0)

do_test_log_likelihood(sac.SACTrainer, config, prev_a)

@@ -15,10 +15,10 @@ halfcheetah_bc:
framework: torch
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

@@ -17,10 +17,10 @@ halfcheetah_cql:
framework: tf
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

@@ -15,10 +15,10 @@ hopper_bc:
framework: torch
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

@@ -15,10 +15,10 @@ hopper_cql:
framework: torch
soft_horizon: False
horizon: 1000
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005

@@ -31,7 +31,7 @@ ddpg-hopperbulletenv-v0:
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: False
worker_side_prioritization: false
learning_starts: 500
clip_rewards: False
actor_lr: 0.001

@@ -37,7 +37,7 @@ mountaincarcontinuous-ddpg:
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: False
worker_side_prioritization: false
clip_rewards: False

# === Optimization ===

@@ -36,7 +36,7 @@ pendulum-ddpg:
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
worker_side_prioritization: False
worker_side_prioritization: false
clip_rewards: False

# === Optimization ===

@@ -14,10 +14,10 @@ atari-sac-tf-and-torch:
framework:
grid_search: [tf, torch]
gamma: 0.99
Q_model:
q_model_config:
hidden_activation: relu
hidden_layer_sizes: [512]
policy_model:
policy_model_config:
hidden_activation: relu
hidden_layer_sizes: [512]
# Do hard syncs.

@@ -8,10 +8,10 @@ halfcheetah-pybullet-sac:
framework: tf
horizon: 1000
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

@@ -9,10 +9,10 @@ halfcheetah_sac:
framework: tf
horizon: 1000
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

@@ -11,10 +11,10 @@ mspacman-sac-tf:
# Works for both torch and tf.
framework: tf
gamma: 0.99
Q_model:
q_model_config:
fcnet_hiddens: [512]
fcnet_activation: relu
policy_model:
policy_model_config:
fcnet_hiddens: [512]
fcnet_activation: relu
# Do hard syncs.

@@ -10,10 +10,10 @@ pendulum-sac-fake-gpus:
framework: tf
horizon: 200
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [ 256, 256 ]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [ 256, 256 ]
tau: 0.005

@@ -12,10 +12,10 @@ pendulum-sac:
framework: tf
horizon: 200
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

@@ -19,10 +19,10 @@ transformed-actions-pendulum-sac-dummy-torch:

horizon: 200
soft_horizon: false
Q_model:
q_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
policy_model_config:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005

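(Illustrative sketch, not part of the diff.) All of the tuned-example YAML files above make the same mechanical rename: Q_model becomes q_model_config and policy_model becomes policy_model_config. On the Python side, these keys correspond to the q_model_config and policy_model_config arguments of SACConfig.training() used earlier in this diff, e.g.:

import ray.rllib.algorithms.sac as sac

config = sac.SACConfig().training(
    tau=0.005,  # passing tau through training() is assumed here
    q_model_config={"fcnet_activation": "relu", "fcnet_hiddens": [256, 256]},
    policy_model_config={"fcnet_activation": "relu", "fcnet_hiddens": [256, 256]},
)
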
@@ -138,6 +138,13 @@ def validate_buffer_config(config: dict):
if config.get("replay_buffer_config", None) is None:
config["replay_buffer_config"] = {}

if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config['worker_side_prioritization']",
new="config['replay_buffer_config']['worker_side_prioritization']",
error=True,
)

prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
if prioritized_replay != DEPRECATED_VALUE:
deprecation_warning(
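(Illustrative sketch, not part of the diff.) The added validation above hard-errors when the old top-level worker_side_prioritization key is still present and points users at its new home inside replay_buffer_config, consistent with the YAML changes earlier in this diff:

# Old, now rejected by validate_buffer_config:
#     config["worker_side_prioritization"] = True
# New placement:
config = {
    "replay_buffer_config": {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "worker_side_prioritization": True,
    },
}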