2020-09-15 03:37:07 -07:00
|
|
|
"""TensorFlow policy class used for Simple Q-Learning"""
|
2019-07-03 15:59:47 -07:00
|
|
|
|
2019-07-07 19:51:26 -07:00
|
|
|
import logging
|
2020-09-09 09:55:26 -07:00
|
|
|
from typing import List, Tuple, Type
|
2019-07-03 15:59:47 -07:00
|
|
|
|
2020-09-09 09:55:26 -07:00
|
|
|
import gym
|
2019-07-03 15:59:47 -07:00
|
|
|
import ray
|
|
|
|
from ray.rllib.models import ModelCatalog
|
2020-09-09 09:55:26 -07:00
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
|
|
from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution
|
2020-04-06 20:56:16 +02:00
|
|
|
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
2020-09-09 09:55:26 -07:00
|
|
|
from ray.rllib.policy import Policy
|
2020-09-15 03:37:07 -07:00
|
|
|
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
2020-09-09 09:55:26 -07:00
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2022-05-25 05:38:03 -07:00
|
|
|
from ray.rllib.policy.tf_mixins import TargetNetworkMixin
|
2019-07-03 15:59:47 -07:00
|
|
|
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
2020-09-09 09:55:26 -07:00
|
|
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
2020-06-16 08:52:20 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
2022-05-25 05:38:03 -07:00
|
|
|
from ray.rllib.utils.tf_utils import huber_loss
|
2022-06-11 15:10:39 +02:00
|
|
|
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
|
2019-07-03 15:59:47 -07:00
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
tf1, tf, tfv = try_import_tf()
|
2019-07-07 19:51:26 -07:00
|
|
|
# Module-level logger.
logger = logging.getLogger(__name__)

# Name given to the online Q-model built in `build_q_models`.
Q_SCOPE = "q_func"
# Name given to the target Q-model built in `build_q_models`
# (assigned to `policy.target_model`).
Q_TARGET_SCOPE = "target_q_func"
|
|
|
|
|
|
|
|
|
2020-09-15 03:37:07 -07:00
|
|
|
def build_q_models(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> ModelV2:
    """Build the Q-model and the target Q-model for Simple Q-learning.

    Note that this function works for both TensorFlow and PyTorch; the
    framework is selected via ``config["framework"]``.

    Args:
        policy: The Policy, which will use the model for optimization. The
            target model is attached to it as ``policy.target_model``.
        obs_space: The policy's observation space.
        action_space: The policy's action space. Must be ``Discrete``; the
            models get one output (Q-value) per action.
        config: The Policy's config dict. ``config["model"]`` is used as the
            model config and ``config["framework"]`` selects the framework.

    Returns:
        ModelV2: The (online) Q-model for the Policy to use.

    Raises:
        UnsupportedSpaceException: If ``action_space`` is not ``Discrete``.

    Note: The target Q-model is not returned, just assigned to
    ``policy.target_model``.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space)
        )

    def _make_q_model(name: str) -> ModelV2:
        # Online and target networks share the exact same architecture and
        # only differ by their name.
        return ModelCatalog.get_model_v2(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=action_space.n,
            model_config=config["model"],
            framework=config["framework"],
            name=name,
        )

    # Keep the original creation order: online model first, then target.
    model = _make_q_model(Q_SCOPE)
    policy.target_model = _make_q_model(Q_TARGET_SCOPE)

    return model
|
2019-07-03 15:59:47 -07:00
|
|
|
|
|
|
|
|
2020-09-09 09:55:26 -07:00
|
|
|
def get_distribution_inputs_and_class(
    policy: Policy,
    q_model: ModelV2,
    obs_batch: TensorType,
    *,
    explore=True,
    is_training=True,
    **kwargs
) -> Tuple[TensorType, type, List[TensorType]]:
    """Compute Q-values for `obs_batch` and pick the matching distribution class.

    The Q-values serve as the action-distribution inputs and are also stored
    on ``policy.q_values`` so they can be reported as extra action outputs.

    Args:
        policy: The Policy computing actions.
        q_model: The Q-model to evaluate.
        obs_batch: The batch of observations.
        explore: Forwarded to ``compute_q_values``.
        is_training: Forwarded to ``compute_q_values``.
        **kwargs: Ignored.

    Returns:
        Tuple of (Q-values, distribution class, state-outs). The class is
        ``TorchCategorical`` when ``policy.config["framework"] == "torch"``
        and ``Categorical`` otherwise; state-outs is always an empty list.
    """
    out = compute_q_values(policy, q_model, obs_batch, explore, is_training)
    # Unwrap the first element if a tuple came back.
    if isinstance(out, tuple):
        out = out[0]
    policy.q_values = out

    if policy.config["framework"] == "torch":
        dist_cls = TorchCategorical
    else:
        dist_cls = Categorical

    # This policy is stateless: no state-outs.
    state_outs = []
    return policy.q_values, dist_cls, state_outs
|
2022-01-29 18:41:57 -08:00
|
|
|
|
|
|
|
|
2020-09-09 09:55:26 -07:00
|
|
|
def build_q_losses(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the loss for SimpleQTFPolicy.

    Builds the one-step TD target
    ``r + gamma * (1 - done) * max_a Q_target(s_tp1, a)`` and returns the
    mean Huber loss between it and ``Q(s_t, a_t)`` for the taken actions.

    Args:
        policy: The Policy to calculate the loss for. Must have a
            ``target_model`` attribute.
        model: The (online) Q-model to calculate the loss for.
        dist_class: The action distribution class (unused here).
        train_batch: The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    # Q-network evaluation. Use the `model` argument (not `policy.model`) so
    # the loss is built on the same model whose variables are registered
    # as `q_func_vars` below.
    q_t = compute_q_values(
        policy, model, train_batch[SampleBatch.CUR_OBS], explore=False
    )

    # Target Q-network evaluation.
    q_tp1 = compute_q_values(
        policy, policy.target_model, train_batch[SampleBatch.NEXT_OBS], explore=False
    )
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = model.variables()
        policy.target_q_func_vars = policy.target_model.variables()

    # Q-scores for the actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(
        tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), policy.action_space.n
    )
    q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)

    # Estimate of the best possible value starting from the state at t + 1;
    # zeroed out for terminal transitions. `reduce_max` equals selecting the
    # argmax column via a one-hot mask and summing.
    dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    q_tp1_best = tf.reduce_max(q_tp1, axis=1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # RHS of the Bellman equation.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] + policy.config["gamma"] * q_tp1_best_masked
    )

    # TD error; the target is treated as a constant (no gradient).
    td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
    loss = tf.reduce_mean(huber_loss(td_error))

    # Save the TD error as an attribute for outside access
    # (it is returned as an extra learn fetch).
    policy.td_error = td_error

    return loss
|
|
|
|
|
|
|
|
|
2020-09-09 09:55:26 -07:00
|
|
|
def compute_q_values(
    policy: Policy, model: ModelV2, obs: TensorType, explore, is_training=None
) -> TensorType:
    """Run `model` on a batch of observations and return its output.

    Args:
        policy: The Policy owning the model; used to obtain the is-training
            placeholder when `is_training` is None.
        model: The model to evaluate.
        obs: The batch of observations.
        explore: Not used by this function; kept for interface compatibility.
        is_training: Training flag passed to the model. If None, the policy's
            is-training placeholder is used instead.

    Returns:
        The first element of the model's (output, state) return value.
    """
    if is_training is None:
        is_training = policy._get_is_training_placeholder()

    input_dict = SampleBatch(obs=obs, _is_training=is_training)
    # The model is called with no RNN state and no sequence lengths.
    model_out, _state_out = model(input_dict, [], None)
    return model_out
|
2019-07-03 15:59:47 -07:00
|
|
|
|
|
|
|
|
2020-09-15 03:37:07 -07:00
|
|
|
def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call mixin constructors that must run AFTER policy initialization.

    SimpleQTFPolicy registers this function as its ``after_init`` hook, so it
    runs once the policy's own initialization has finished. It then
    initializes the ``TargetNetworkMixin``. Presumably the mixin needs the
    already-built models/variables; confirm against ``TargetNetworkMixin``.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space.
        action_space: The Policy's action space.
        config: The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
|
|
|
|
|
|
|
|
2020-09-15 03:37:07 -07:00
|
|
|
# Build a child class of `DynamicTFPolicy` from the custom functions defined
# above.


def _simple_q_default_config():
    # Resolved lazily (at call time) through attribute access on `ray`.
    return ray.rllib.algorithms.simple_q.simple_q.DEFAULT_CONFIG


def _simple_q_extra_action_out(policy):
    # Q-values stored on the policy by `get_distribution_inputs_and_class`.
    return {"q_values": policy.q_values}


def _simple_q_extra_learn_fetches(policy):
    # TD error stored on the policy by `build_q_losses`.
    return {"td_error": policy.td_error}


SimpleQTFPolicy: Type[DynamicTFPolicy] = build_tf_policy(
    name="SimpleQTFPolicy",
    get_default_config=_simple_q_default_config,
    # Model construction and action computation.
    make_model=build_q_models,
    action_distribution_fn=get_distribution_inputs_and_class,
    # Loss.
    loss_fn=build_q_losses,
    # Extra outputs reported to the caller.
    extra_action_out_fn=_simple_q_extra_action_out,
    extra_learn_fetches_fn=_simple_q_extra_learn_fetches,
    # Mixins; `TargetNetworkMixin` is initialized in `setup_late_mixins`.
    mixins=[TargetNetworkMixin],
    after_init=setup_late_mixins,
)
|