"""
TensorFlow policy class used for SAC.
"""

import copy
import gym
from gym.spaces import Box, Discrete
from functools import partial
import logging
from typing import Dict, List, Optional, Tuple, Type, Union

import ray
import ray.experimental.tf_utils
from ray.rllib.algorithms.ddpg.ddpg_tf_policy import (
    ComputeTDErrorMixin,
    TargetNetworkMixin,
)
from ray.rllib.algorithms.dqn.dqn_tf_policy import (
    postprocess_nstep_and_prio,
    PRIO_WEIGHTS,
)
from ray.rllib.algorithms.sac.sac_tf_model import SACTFModel
from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import (
    Beta,
    Categorical,
    DiagGaussian,
    Dirichlet,
    SquashedGaussian,
    TFActionDistribution,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import get_variable, try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.tf_utils import huber_loss
from ray.rllib.utils.typing import (
    AgentID,
    LocalOptimizer,
    ModelGradients,
    TensorType,
    AlgorithmConfigDict,
)

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def build_sac_model(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> ModelV2:
    """Constructs the necessary ModelV2 for the Policy and returns it.

    Args:
        policy: The TFPolicy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config: The SAC trainer's config dict.

    Returns:
        ModelV2: The ModelV2 to be used by the Policy. Note: An additional
            target model will be created in this function and assigned to
            `policy.target_model`.
    """
    # Force-ignore any additionally provided hidden layer sizes.
    # Everything should be configured using SAC's `q_model_config` and
    # `policy_model_config` config settings.
    policy_model_config = copy.deepcopy(MODEL_DEFAULTS)
    policy_model_config.update(config["policy_model_config"])
    q_model_config = copy.deepcopy(MODEL_DEFAULTS)
    q_model_config.update(config["q_model_config"])
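    # Illustration only (example values are an assumption, not RLlib defaults):
    # after the two `update()` calls above, each dict typically looks like
    #     {"fcnet_hiddens": [256, 256], "fcnet_activation": "relu", ...}
    # i.e. MODEL_DEFAULTS overridden by the user-provided
    # `policy_model_config` / `q_model_config` settings.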

    default_model_cls = SACTorchModel if config["framework"] == "torch" else SACTFModel

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=None,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"],
    )

    assert isinstance(model, default_model_cls)

    # Create an exact copy of the model and store it in `policy.target_model`.
    # This will be used for tau-synched Q-target models that run behind the
    # actual Q-networks and are used for target q-value calculations in the
    # loss terms.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=None,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="target_sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"],
    )

    assert isinstance(policy.target_model, default_model_cls)

    return model


def postprocess_trajectory(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
    contain a truncated (at-the-end) episode, in case the
    `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
    exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy: The Policy used to generate the trajectory (`sample_batch`).
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from the
            same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    return postprocess_nstep_and_prio(policy, sample_batch)


def _get_dist_class(
    policy: Policy, config: AlgorithmConfigDict, action_space: gym.spaces.Space
) -> Type[TFActionDistribution]:
    """Helper function to return a dist class based on config and action space.

    Args:
        policy: The policy for which to return the action
            dist class.
        config: The Algorithm's config dict.
        action_space (gym.spaces.Space): The action space used.

    Returns:
        Type[TFActionDistribution]: A TF distribution class.
    """
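    # For quick reference, the branches below amount to this mapping:
    #   custom_action_dist in model config -> class from ModelCatalog
    #   Discrete                           -> Categorical
    #   Simplex                            -> Dirichlet
    #   Box (normalize_actions=True)       -> SquashedGaussian, or Beta if
    #                                         `_use_beta_distribution` is set
    #   Box (otherwise)                    -> DiagGaussian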
    if hasattr(policy, "dist_class") and policy.dist_class is not None:
        return policy.dist_class
    elif config["model"].get("custom_action_dist"):
        action_dist_class, _ = ModelCatalog.get_action_dist(
            action_space, config["model"], framework="tf"
        )
        return action_dist_class
    elif isinstance(action_space, Discrete):
        return Categorical
    elif isinstance(action_space, Simplex):
        return Dirichlet
    else:
        assert isinstance(action_space, Box)
        if config["normalize_actions"]:
            return SquashedGaussian if not config["_use_beta_distribution"] else Beta
        else:
            return DiagGaussian


def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    obs_batch: TensorType,
    *,
    explore: bool = True,
    **kwargs
) -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]:
    """The action distribution function to be used by the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call
    will be made on it to generate actions.

    Args:
        policy: The Policy being queried for actions and calling this
            function.
        model: The SAC specific Model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_action_model_outputs` method.
        obs_batch: The observations to be used as inputs to the
            model.
        explore: Whether to activate exploration or not.

    Returns:
        Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: The
            dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
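    # Note: This function is hooked into the policy below via
    # `action_distribution_fn=get_distribution_inputs_and_class` in the
    # `build_tf_policy` call at the end of this file.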
    # Get base-model (forward) output (this should be a noop call).
    forward_out, state_out = model(
        SampleBatch(obs=obs_batch, _is_training=policy._get_is_training_placeholder()),
        [],
        None,
    )
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    distribution_inputs, _ = model.get_action_model_outputs(forward_out)
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)

    return distribution_inputs, action_dist_class, state_out


def sac_actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    _is_training = policy._get_is_training_placeholder()
    # Get the base model output from the train batch.
    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=_is_training),
        [],
        None,
    )

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training),
        [],
        None,
    )

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training),
        [],
        None,
    )

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = tf.nn.log_softmax(action_dist_inputs_t, -1)
        policy_t = tf.math.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
        log_pis_tp1 = tf.nn.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = policy.target_model.get_twin_q_values(target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1]
        )
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = (
            1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)
        ) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, policy.model)
        policy_t = (
            action_dist_t.sample()
            if not deterministic
            else action_dist_t.deterministic_sample()
        )
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, policy.model)
        policy_tp1 = (
            action_dist_tp1.sample()
            if not deterministic
            else action_dist_tp1.deterministic_sample()
        )
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(
            model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
        )
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0
            )

        # target q network evaluation
        q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1
            )
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (
            1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)
        ) * q_tp1_best

    # Compute RHS of bellman equation for the Q-loss (critic(s)).
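    # For reference (continuous-action sketch):
    #     y = r_t + gamma^n_step * (1 - done) *
    #         (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))
    # with a' ~ pi(.|s'); the entropy term and the twin-min were already folded
    # into `q_tp1` above, and the discrete case uses an exact expectation over
    # actions instead of a sampled a'.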
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
        + policy.config["gamma"] ** policy.config["n_step"] * q_tp1_best_masked
    )

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
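    # Sketch of the objectives implemented in both branches below:
    #     J(alpha) = E[ -log_alpha * (log pi(a|s) + target_entropy) ]
    #     J(pi)    = E[ alpha * log pi(a|s) - Q(s, a) ]
    # where alpha = exp(log_alpha); in the discrete case, the expectation over
    # actions is computed exactly by weighting with pi(a|s).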
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t),
                    -model.log_alpha
                    * tf.stop_gradient(log_pis_t + model.target_entropy),
                ),
                axis=-1,
            )
        )
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t),
                ),
                axis=-1,
            )
        )
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)
        )
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss


def compute_and_clip_gradients(
    policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Note: For SAC, optimizer and loss are ignored b/c we have 3
    losses and 3 local optimizers (all stored in policy).
    `optimizer` will be used, though, in the tf-eager case b/c it is then a
    fake optimizer (OptimizerWrapper) object with a `tape` property to
    generate a GradientTape object for gradient recording.

    Args:
        policy: The Policy object that generated the loss tensor and
            that holds the given local optimizer.
        optimizer: The tf (local) optimizer object to
            calculate the gradients with.
        loss: The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the possibly clipped gradients- and variable
            tuples.
    """
    # Eager: Use GradientTape (which is a property of the `optimizer` object
    # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
    if policy.config["framework"] in ["tf2", "tfe"]:
        tape = optimizer.tape
        pol_weights = policy.model.policy_variables()
        actor_grads_and_vars = list(
            zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights)
        )
        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            grads_1 = tape.gradient(policy.critic_loss[0], q_weights[:half_cutoff])
            grads_2 = tape.gradient(policy.critic_loss[1], q_weights[half_cutoff:])
            critic_grads_and_vars = list(zip(grads_1, q_weights[:half_cutoff])) + list(
                zip(grads_2, q_weights[half_cutoff:])
            )
        else:
            critic_grads_and_vars = list(
                zip(tape.gradient(policy.critic_loss[0], q_weights), q_weights)
            )

        alpha_vars = [policy.model.log_alpha]
        alpha_grads_and_vars = list(
            zip(tape.gradient(policy.alpha_loss, alpha_vars), alpha_vars)
        )
    # Tf1.x: Use optimizer.compute_gradients()
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables()
        )

        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
            critic_grads_and_vars = base_q_optimizer.compute_gradients(
                policy.critic_loss[0], var_list=q_weights[:half_cutoff]
            ) + twin_q_optimizer.compute_gradients(
                policy.critic_loss[1], var_list=q_weights[half_cutoff:]
            )
        else:
            critic_grads_and_vars = policy._critic_optimizer[0].compute_gradients(
                policy.critic_loss[0], var_list=q_weights
            )
        alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
            policy.alpha_loss, var_list=[policy.model.log_alpha]
        )

    # Clip if necessary.
    if policy.config["grad_clip"]:
        clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
    else:
        clip_func = tf.identity

    # Save grads and vars for later use in `build_apply_op`.
    policy._actor_grads_and_vars = [
        (clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None
    ]
    policy._critic_grads_and_vars = [
        (clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None
    ]
    policy._alpha_grads_and_vars = [
        (clip_func(g), v) for (g, v) in alpha_grads_and_vars if g is not None
    ]

    grads_and_vars = (
        policy._actor_grads_and_vars
        + policy._critic_grads_and_vars
        + policy._alpha_grads_and_vars
    )
    return grads_and_vars


def apply_gradients(
    policy: Policy, optimizer: LocalOptimizer, grads_and_vars: ModelGradients
) -> Union["tf.Operation", None]:
    """Gradients applying function (from list of "grad_and_var" tuples).

    Note: For SAC, optimizer and grads_and_vars are ignored b/c we have 3
    losses and optimizers (stored in policy).

    Args:
        policy: The Policy object whose Model(s) the given gradients
            should be applied to.
        optimizer: The tf (local) optimizer object through
            which to apply the gradients.
        grads_and_vars: The list of grad_and_var tuples to
            apply via the given optimizer.

    Returns:
        Union[tf.Operation, None]: The tf op to be used to run the apply
            operation. None for eager mode.
    """
    actor_apply_ops = policy._actor_optimizer.apply_gradients(
        policy._actor_grads_and_vars
    )

    cgrads = policy._critic_grads_and_vars
    half_cutoff = len(cgrads) // 2
    if policy.config["twin_q"]:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]),
            policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:]),
        ]
    else:
        critic_apply_ops = [policy._critic_optimizer[0].apply_gradients(cgrads)]

    # Eager mode -> Just apply and return None.
    if policy.config["framework"] in ["tf2", "tfe"]:
        policy._alpha_optimizer.apply_gradients(policy._alpha_grads_and_vars)
        return
    # Tf static graph -> Return op.
    else:
        alpha_apply_ops = policy._alpha_optimizer.apply_gradients(
            policy._alpha_grads_and_vars,
            global_step=tf1.train.get_or_create_global_step(),
        )
        return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops)


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for SAC. Returns a dict with important loss stats.

    Args:
        policy: The Policy to generate stats for.
        train_batch: The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "mean_td_error": tf.reduce_mean(policy.td_error),
        "actor_loss": tf.reduce_mean(policy.actor_loss),
        "critic_loss": tf.reduce_mean(policy.critic_loss),
        "alpha_loss": tf.reduce_mean(policy.alpha_loss),
        "alpha_value": tf.reduce_mean(policy.alpha_value),
        "target_entropy": tf.constant(policy.target_entropy),
        "mean_q": tf.reduce_mean(policy.q_t),
        "max_q": tf.reduce_max(policy.q_t),
        "min_q": tf.reduce_min(policy.q_t),
    }


class ActorCriticOptimizerMixin:
    """Mixin class to generate the necessary optimizers for actor-critic algos.

    - Creates global step for counting the number of update operations.
    - Creates separate optimizers for actor, critic, and alpha.
    """

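    # The three optimizer attributes created here (`_actor_optimizer`,
    # `_critic_optimizer` [a list of one or two optimizers], and
    # `_alpha_optimizer`) are the ones consumed by
    # `compute_and_clip_gradients` and `apply_gradients` above.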
    def __init__(self, config):
        # Eager mode.
        if config["framework"] in ["tf2", "tfe"]:
            self.global_step = get_variable(0, tf_name="global_step")
            self._actor_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["optimization"]["actor_learning_rate"]
            )
            self._critic_optimizer = [
                tf.keras.optimizers.Adam(
                    learning_rate=config["optimization"]["critic_learning_rate"]
                )
            ]
            if config["twin_q"]:
                self._critic_optimizer.append(
                    tf.keras.optimizers.Adam(
                        learning_rate=config["optimization"]["critic_learning_rate"]
                    )
                )
            self._alpha_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["optimization"]["entropy_learning_rate"]
            )
        # Static graph mode.
        else:
            self.global_step = tf1.train.get_or_create_global_step()
            self._actor_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["optimization"]["actor_learning_rate"]
            )
            self._critic_optimizer = [
                tf1.train.AdamOptimizer(
                    learning_rate=config["optimization"]["critic_learning_rate"]
                )
            ]
            if config["twin_q"]:
                self._critic_optimizer.append(
                    tf1.train.AdamOptimizer(
                        learning_rate=config["optimization"]["critic_learning_rate"]
                    )
                )
            self._alpha_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["optimization"]["entropy_learning_rate"]
            )


def setup_early_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy.

    Args:
        policy: The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config: The Policy's config.
    """
    ActorCriticOptimizerMixin.__init__(policy, config)


def setup_mid_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call mixin classes' constructors before Policy's loss initialization.

    Adds the `compute_td_error` method to the given policy.
    Calling `compute_td_error` with batch data will re-calculate the loss
    on that batch AND return the per-batch-item TD-error for prioritized
    replay buffer record weight updating (in case a prioritized replay buffer
    is used).

    Args:
        policy: The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config: The Policy's config.
    """
    ComputeTDErrorMixin.__init__(policy, sac_actor_critic_loss)


def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call mixin classes' constructors after Policy initialization.

    Adds the `update_target` method to the given policy.
    Calling `update_target` updates all target Q-networks' weights from their
    respective "main" Q-networks, based on tau (smooth, partial updating).

    Args:
        policy: The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config: The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, config)


def validate_spaces(
    policy: Policy,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Validates the observation- and action spaces used for the Policy.

    Args:
        policy: The policy, whose spaces are being validated.
        observation_space (gym.spaces.Space): The observation space to
            validate.
        action_space (gym.spaces.Space): The action space to validate.
        config: The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # Only support single Box or single Discrete spaces.
    if not isinstance(action_space, (Box, Discrete, Simplex)):
        raise UnsupportedSpaceException(
            "Action space ({}) of {} is not supported for "
            "SAC. Must be [Box|Discrete|Simplex].".format(action_space, policy)
        )
    # If Box, make sure it's a 1D vector space.
    elif isinstance(action_space, (Box, Simplex)) and len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space ({}) of {} has multiple dimensions "
            "{}. ".format(action_space, policy, action_space.shape)
            + "Consider reshaping this into a single dimension, "
            "using a Tuple action space, or the multi-agent API."
        )


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
SACTFPolicy = build_tf_policy(
    name="SACTFPolicy",
    get_default_config=lambda: ray.rllib.algorithms.sac.sac.DEFAULT_CONFIG,
    make_model=build_sac_model,
    postprocess_fn=postprocess_trajectory,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=sac_actor_critic_loss,
    stats_fn=stats,
    compute_gradients_fn=compute_and_clip_gradients,
    apply_gradients_fn=apply_gradients,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    mixins=[TargetNetworkMixin, ActorCriticOptimizerMixin, ComputeTDErrorMixin],
    validate_spaces=validate_spaces,
    before_init=setup_early_mixins,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
)
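
# Usage sketch (illustrative only; assumes the standard SAC setup and is not
# part of the original module): `SACTFPolicy` is normally created indirectly
# by the SAC algorithm rather than instantiated by hand, e.g.:
#
#     from ray.rllib.algorithms.sac import SACConfig
#
#     algo = SACConfig().environment("Pendulum-v1").framework("tf").build()
#     print(algo.train())
#
# The algorithm then constructs one `SACTFPolicy` per rollout worker via the
# `build_tf_policy` template above.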