import errno
import gym
import logging
import math
import numpy as np
import os
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import ray
import ray.experimental.tf_utils
from ray.util.debug import log_once
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_utils import get_gpu_devices
from ray.rllib.utils.tf_run_builder import _TFRunBuilder
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
    TrainerConfigDict,
)

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode

tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


@DeveloperAPI
class TFPolicy(Policy):
    """An agent policy and loss implemented in TensorFlow.

    Do not sub-class this class directly (neither should you sub-class
    DynamicTFPolicy), but rather use
    rllib.policy.tf_policy_template.build_tf_policy
    to generate your custom tf (graph-mode or eager) Policy classes.

    Extending this class enables RLlib to perform TensorFlow-specific
    optimizations on the policy, e.g., parallelization across GPUs or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    Examples:
        >>> from ray.rllib.policy import TFPolicy
        >>> class TFPolicySubclass(TFPolicy): # doctest: +SKIP
        ...     ... # doctest: +SKIP
        >>> sess, obs_input, sampled_action, loss, loss_inputs = ... # doctest: +SKIP
        >>> policy = TFPolicySubclass( # doctest: +SKIP
        ...     sess, obs_input, sampled_action, loss, loss_inputs) # doctest: +SKIP
        >>> print(policy.compute_actions([1, 0, 2])) # doctest: +SKIP
        (array([0, 1, 1]), [], {})
        >>> print(policy.postprocess_trajectory(SampleBatch({...}))) # doctest: +SKIP
        SampleBatch({"action": ..., "advantages": ..., ...})
    """

    @DeveloperAPI
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        sess: "tf1.Session",
        obs_input: TensorType,
        sampled_action: TensorType,
        loss: Union[TensorType, List[TensorType]],
        loss_inputs: List[Tuple[str, TensorType]],
        model: Optional[ModelV2] = None,
        sampled_action_logp: Optional[TensorType] = None,
        action_input: Optional[TensorType] = None,
        log_likelihood: Optional[TensorType] = None,
        dist_inputs: Optional[TensorType] = None,
        dist_class: Optional[type] = None,
        state_inputs: Optional[List[TensorType]] = None,
        state_outputs: Optional[List[TensorType]] = None,
        prev_action_input: Optional[TensorType] = None,
        prev_reward_input: Optional[TensorType] = None,
        seq_lens: Optional[TensorType] = None,
        max_seq_len: int = 20,
        batch_divisibility_req: int = 1,
        update_ops: List[TensorType] = None,
        explore: Optional[TensorType] = None,
        timestep: Optional[TensorType] = None,
    ):
        """Initializes a TFPolicy object.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            sess: The TensorFlow session to use.
            obs_input: Input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            sampled_action: Tensor for sampling an action, of shape
                [BATCH_SIZE, action...].
            loss: Scalar policy loss output tensor or a list thereof
                (in case there is more than one loss).
            loss_inputs: A (name, placeholder) tuple for each loss input
                argument. Each placeholder name must
                correspond to a SampleBatch column key returned by
                postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
                These keys will be read from postprocessed sample batches and
                fed into the specified placeholders during loss computation.
            model: The optional ModelV2 to use for calculating actions and
                losses. If not None, TFPolicy will provide functionality for
                getting variables, calling the model's custom loss (if
                provided), and importing weights into the model.
            sampled_action_logp: Log probability of the sampled action.
            action_input: Input placeholder for actions for
                logp/log-likelihood calculations.
            log_likelihood: Tensor to calculate the log_likelihood (given
                action_input and obs_input).
            dist_class: An optional ActionDistribution class to use for
                generating a dist object from distribution inputs.
            dist_inputs: Tensor to calculate the distribution
                inputs/parameters.
            state_inputs: List of RNN state input Tensors.
            state_outputs: List of RNN state output Tensors.
            prev_action_input: Placeholder for previous actions.
            prev_reward_input: Placeholder for previous rewards.
            seq_lens: Placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES].
                Note that NUM_SEQUENCES << BATCH_SIZE. See
                policy/rnn_sequencing.py for more information.
            max_seq_len: Max sequence length for LSTM training.
            batch_divisibility_req: Pad all agent experience batches to
                multiples of this value. This only has an effect if not using
                an LSTM model.
            update_ops: Override the batchnorm update ops
                to run when applying gradients. Otherwise we run all update
                ops found in the current variable scope.
            explore: Placeholder for `explore` parameter into call to
                Exploration.get_exploration_action. Explicitly set this to
                False for not creating any Exploration component.
            timestep: Placeholder for the global sampling timestep.
        """
        self.framework = "tf"
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        worker_idx = config.get("worker_index", 0)
        if not config["_fake_gpus"] and ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = get_gpu_devices()

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info(
                "TFPolicy (worker={}) running on {}.".format(
                    worker_idx if worker_idx > 0 else "local",
                    f"{num_gpus} fake-GPUs" if config["_fake_gpus"] else "CPU",
                )
            )
            self.devices = ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info(
                "TFPolicy (worker={}) running on {} GPU(s).".format(
                    worker_idx if worker_idx > 0 else "local", num_gpus
                )
            )

            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus]

        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_compute_actions = False
            self.view_requirements[SampleBatch.INFOS].used_for_training = False
            # Optionally add `infos` to the output dataset.
            if self.config["output_config"].get("store_infos", False):
                self.view_requirements[SampleBatch.INFOS].used_for_training = True

        assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), (
            "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` "
            "not allowed! You passed in {}.".format(model)
        )
        self.model = model
        # Auto-update model's inference view requirements, if recurrent.
        if self.model is not None:
            self._update_model_view_requirements_from_init_state()

        # If `explore` is explicitly set to False, don't create an exploration
        # component.
        self.exploration = self._create_exploration() if explore is not False else None

        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampled_action = sampled_action
        self._is_training = self._get_is_training_placeholder()
        self._is_exploring = (
            explore
            if explore is not None
            else tf1.placeholder_with_default(True, (), name="is_exploring")
        )
        self._sampled_action_logp = sampled_action_logp
        self._sampled_action_prob = (
            tf.math.exp(self._sampled_action_logp)
            if self._sampled_action_logp is not None
            else None
        )
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class
        self._cached_extra_action_out = None
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len

        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined"
            )

        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._apply_op = None
        self._stats_fetches = {}
        self._timestep = (
            timestep
            if timestep is not None
            else tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
        )

        self._optimizers: List[LocalOptimizer] = []
        # Backward compatibility and for some code shared with tf-eager Policy.
        self._optimizer = None

        self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
        self._grads: Union[ModelGradients, List[ModelGradients]] = []
        # Policy tf-variables (weights), whose values to get/set via
        # get_weights/set_weights.
        self._variables = None
        # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
        # Will be stored alongside `self._variables` when checkpointing.
        self._optimizer_variables: Optional[
            ray.experimental.tf_utils.TensorFlowVariables
        ] = None

        # The loss tf-op(s). Number of losses must match number of optimizers.
        self._losses = []
        # Backward compatibility (in case custom child TFPolicies access this
        # property).
        self._loss = None
        # A batch dict passed into loss function as input.
        self._loss_input_dict = {}
        losses = force_list(loss)
        if len(losses) > 0:
            self._initialize_loss(losses, loss_inputs)

        # The log-likelihood calculator op.
        self._log_likelihood = log_likelihood
        if (
            self._log_likelihood is None
            and self._dist_inputs is not None
            and self.dist_class is not None
        ):
            self._log_likelihood = self.dist_class(self._dist_inputs, self.model).logp(
                self._action_input
            )

    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch off is_training flag in our batch.
        if isinstance(input_dict, SampleBatch):
            input_dict.set_training(False)
        else:
            # Deprecated dict input.
            input_dict["is_training"] = False

        builder = _TFRunBuilder(self.get_session(), "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else len(input_dict)
            if isinstance(input_dict, SampleBatch)
            else obs_batch.shape[0]
        )

        return fetched

    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ):
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = _TFRunBuilder(self.get_session(), "compute_actions")

        input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
        if state_batches:
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else tree.flatten(obs_batch)[0].shape[0]
        )

        return fetched
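
    # Illustrative usage sketch (not part of the original API docs; assumes a
    # built TFPolicy subclass instance `policy` and 4-dim observation vectors):
    #
    #     actions, state_outs, extra_fetches = policy.compute_actions(
    #         obs_batch=np.array([[0.1, 0.2, 0.3, 0.4]]),
    #         explore=False,
    #     )
    #
    # `extra_fetches` then holds whatever `extra_compute_action_fetches()`
    # returns, e.g. action log-probs and action-distribution inputs.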

    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
    ) -> TensorType:

        if self._log_likelihood is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o a self._log_likelihood op!"
            )

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(
            explore=False, tf_sess=self.get_session()
        )

        builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods")

        # Normalize actions if necessary.
        if actions_normalized is False and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
        builder.add_feed_dict({self._obs_input: obs_batch})
        # Internal states.
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".format(
                    self._state_inputs, state_batches
                )
            )
        builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)})
        if state_batches:
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Prev-a and r.
        if self._prev_action_input is not None and prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        # Fetch the log_likelihoods output and return.
        fetches = builder.add_fetches([self._log_likelihood])
        return builder.get(fetches)[0]
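
    # Illustrative sketch (assumes a discrete-action policy `policy` and
    # already-normalized actions): per-sample log-likelihoods can be queried
    # directly from the graph, e.g.:
    #
    #     logp = policy.compute_log_likelihoods(
    #         actions=np.array([0, 1, 1]),
    #         obs_batch=np.array([obs0, obs1, obs2]),
    #         actions_normalized=True,
    #     )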

    @override(Policy)
    @DeveloperAPI
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        assert self.loss_initialized()

        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)

        builder = _TFRunBuilder(self.get_session(), "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        stats.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
            }
        )
        return stats
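
    # Illustrative sketch (assumes `train_batch` is a postprocessed SampleBatch
    # containing every key referenced by the loss): one update step returns the
    # learner stats, e.g.:
    #
    #     stats = policy.learn_on_batch(train_batch)
    #     print(stats[LEARNER_STATS_KEY], stats["custom_metrics"])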

    @override(Policy)
    @DeveloperAPI
    def compute_gradients(
        self, postprocessed_batch: SampleBatch
    ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
        assert self.loss_initialized()
        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)
        builder = _TFRunBuilder(self.get_session(), "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(Policy)
    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        assert self.loss_initialized()
        builder = _TFRunBuilder(self.get_session(), "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)

    @override(Policy)
    @DeveloperAPI
    def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        return self._variables.get_weights()

    @override(Policy)
    @DeveloperAPI
    def set_weights(self, weights) -> None:
        return self._variables.set_weights(weights)

    @override(Policy)
    @DeveloperAPI
    def get_exploration_state(self) -> Dict[str, TensorType]:
        return self.exploration.get_state(sess=self.get_session())

    @Deprecated(new="get_exploration_state", error=False)
    def get_exploration_info(self) -> Dict[str, TensorType]:
        return self.get_exploration_state()

    @override(Policy)
    @DeveloperAPI
    def is_recurrent(self) -> bool:
        return len(self._state_inputs) > 0

    @override(Policy)
    @DeveloperAPI
    def num_state_tensors(self) -> int:
        return len(self._state_inputs)

    @override(Policy)
    @DeveloperAPI
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        # For tf Policies, return Policy weights and optimizer var values.
        state = super().get_state()
        if len(self._optimizer_variables.variables) > 0:
            state["_optimizer_variables"] = self.get_session().run(
                self._optimizer_variables.variables
            )
        # Add exploration state.
        state["_exploration_state"] = self.exploration.get_state(self.get_session())
        return state

    @override(Policy)
    @DeveloperAPI
    def set_state(self, state: dict) -> None:
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars is not None:
            self._optimizer_variables.set_weights(optimizer_vars)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(
                state=state["_exploration_state"], sess=self.get_session()
            )

        # Set the Policy's (NN) weights.
        super().set_state(state)

    @override(Policy)
    @DeveloperAPI
    def export_checkpoint(
        self, export_dir: str, filename_prefix: str = "model"
    ) -> None:
        """Export tensorflow checkpoint to export_dir."""
        try:
            os.makedirs(export_dir)
        except OSError as e:
            # ignore error if export dir already exists
            if e.errno != errno.EEXIST:
                raise
        save_path = os.path.join(export_dir, filename_prefix)
        with self.get_session().graph.as_default():
            saver = tf1.train.Saver()
            saver.save(self.get_session(), save_path)

    @override(Policy)
    @DeveloperAPI
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Export tensorflow graph to export_dir for serving."""
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`."
                ) from e

            with self.get_session().graph.as_default():
                signature_def_map = self._build_signature_def()

                sd = signature_def_map[
                    tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # noqa: E501
                ]
                inputs = [v.name for k, v in sd.inputs.items()]
                outputs = [v.name for k, v in sd.outputs.items()]

                from tf2onnx import tf_loader

                frozen_graph_def = tf_loader.freeze_session(
                    self._sess, input_names=inputs, output_names=outputs
                )

            with tf1.Session(graph=tf.Graph()) as session:
                tf.import_graph_def(frozen_graph_def, name="")

                g = tf2onnx.tfonnx.process_tf_graph(
                    session.graph,
                    input_names=inputs,
                    output_names=outputs,
                    inputs_as_nchw=inputs,
                )

                model_proto = g.make_model("onnx_model")
                tf2onnx.utils.save_onnx_model(
                    export_dir, "saved_model", feed_dict={}, model_proto=model_proto
                )
        else:
            with self.get_session().graph.as_default():
                signature_def_map = self._build_signature_def()
                builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    self.get_session(),
                    [tf1.saved_model.tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    saver=tf1.summary.FileWriter(export_dir).add_graph(
                        graph=self.get_session().graph
                    ),
                )
                builder.save()
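
    # Illustrative sketch (the `/tmp/...` paths are placeholders; the ONNX
    # branch additionally requires `tf2onnx` to be installed):
    #
    #     policy.export_model("/tmp/tfpolicy_saved_model")      # SavedModel
    #     policy.export_model("/tmp/tfpolicy_onnx", onnx=11)    # ONNX export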

    @override(Policy)
    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into tf model."""
        if self.model is None:
            raise NotImplementedError("No `self.model` to import into!")

        # Make sure the session is the right one (see issue #7046).
        with self.get_session().graph.as_default():
            with self.get_session().as_default():
                return self.model.import_from_h5(import_file)

    @override(Policy)
    def get_session(self) -> Optional["tf1.Session"]:
        """Returns a reference to the TF session for this policy."""
        return self._sess

    def variables(self):
        """Return the list of all savable variables for this policy."""
        if self.model is None:
            raise NotImplementedError("No `self.model` to get variables for!")
        elif isinstance(self.model, tf.keras.Model):
            return self.model.variables
        else:
            return self.model.variables()

    def get_placeholder(self, name) -> "tf1.placeholder":
        """Returns the given action or loss input placeholder by name.

        If the loss has not been initialized and a loss input placeholder is
        requested, an error is raised.

        Args:
            name (str): The name of the placeholder to return. One of
                SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
                `self._loss_input_dict`.

        Returns:
            tf1.placeholder: The placeholder under the given str key.
        """
        if name == SampleBatch.CUR_OBS:
            return self._obs_input
        elif name == SampleBatch.PREV_ACTIONS:
            return self._prev_action_input
        elif name == SampleBatch.PREV_REWARDS:
            return self._prev_reward_input

        assert self._loss_input_dict, (
            "You need to populate `self._loss_input_dict` before "
            "`get_placeholder()` can be called"
        )
        return self._loss_input_dict[name]
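
    # Illustrative sketch (assumes the loss has been initialized so that
    # `self._loss_input_dict` is populated, e.g. with an "advantages" column):
    #
    #     obs_ph = policy.get_placeholder(SampleBatch.CUR_OBS)
    #     adv_ph = policy.get_placeholder("advantages")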

    def loss_initialized(self) -> bool:
        """Returns whether the loss term(s) have been initialized."""
        return len(self._losses) > 0

    def _initialize_loss(
        self, losses: List[TensorType], loss_inputs: List[Tuple[str, TensorType]]
    ) -> None:
        """Initializes the loss op from given loss tensor and placeholders.

        Args:
            losses (List[TensorType]): The list of loss ops returned by some
                loss function.
            loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
                (name, tf1.placeholders) needed for calculating the loss.
        """
        self._loss_input_dict = dict(loss_inputs)
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model and not isinstance(self.model, tf.keras.Model):
            self._losses = force_list(
                self.model.custom_loss(losses, self._loss_input_dict)
            )
            self._stats_fetches.update({"model": self.model.metrics()})
        else:
            self._losses = losses
        # Backward compatibility.
        self._loss = self._losses[0] if self._losses is not None else None

        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0] if self._optimizers else None

        # Supporting more than one loss/optimizer.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self._grads_and_vars = []
            self._grads = []
            for group in self.gradients(self._optimizers, self._losses):
                g_and_v = [(g, v) for (g, v) in group if g is not None]
                self._grads_and_vars.append(g_and_v)
                self._grads.append([g for (g, _) in g_and_v])
        # Only one optimizer and loss term.
        else:
            self._grads_and_vars = [
                (g, v)
                for (g, v) in self.gradients(self._optimizer, self._loss)
                if g is not None
            ]
            self._grads = [g for (g, _) in self._grads_and_vars]

        if self.model:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self.get_session(), self.variables()
            )

        # Gather update ops for any batch norm layers.
        if len(self.devices) <= 1:
            if not self._update_ops:
                self._update_ops = tf1.get_collection(
                    tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
                )
            if self._update_ops:
                logger.info(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )
            with tf1.control_dependencies(self._update_ops):
                self._apply_op = self.build_apply_op(
                    optimizer=self._optimizers
                    if self.config["_tf_policy_handles_more_than_one_loss"]
                    else self._optimizer,
                    grads_and_vars=self._grads_and_vars,
                )

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss functions:"
                f"\n{summarize(self._loss_input_dict)}\n"
            )

        self.get_session().run(tf1.global_variables_initializer())

        # TensorFlowVariables holding a flat list of all our optimizers'
        # variables.
        self._optimizer_variables = ray.experimental.tf_utils.TensorFlowVariables(
            [v for o in self._optimizers for v in o.variables()], self.get_session()
        )

    @DeveloperAPI
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> "TFPolicy":
        """Creates a copy of self using existing input placeholders.

        Optional: Only required to work with the multi-GPU optimizer.

        Args:
            existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping
                names (str) to tf1.placeholders to re-use (share) with the
                returned copy of self.

        Returns:
            TFPolicy: A copy of self.
        """
        raise NotImplementedError

    @DeveloperAPI
    def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
        """Extra dict to pass to the compute actions session run.

        Returns:
            Dict[TensorType, TensorType]: A feed dict to be added to the
                feed_dict passed to the compute_actions session.run() call.
        """
        return {}

    @DeveloperAPI
    def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
        # Cache graph fetches for action computation for better
        # performance.
        # This function is called every time the static graph is run
        # to compute actions.
        if not self._cached_extra_action_out:
            self._cached_extra_action_out = self.extra_action_out_fn()
        return self._cached_extra_action_out

    @DeveloperAPI
    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        """Extra values to fetch and return from compute_actions().

        By default we return action probability/log-likelihood info
        and action distribution inputs (if present).

        Returns:
            Dict[str, TensorType]: An extra fetch-dict to be passed to and
                returned from the compute_actions() call.
        """
        extra_fetches = {}
        # Action-logp and action-prob.
        if self._sampled_action_logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = self._sampled_action_prob
            extra_fetches[SampleBatch.ACTION_LOGP] = self._sampled_action_logp
        # Action-dist inputs.
        if self._dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs
        return extra_fetches
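
    # Illustrative override sketch (hypothetical subclass; assumes `self.model`
    # exposes a `value_function()` tensor):
    #
    #     class MyTFPolicy(TFPolicy):
    #         def extra_action_out_fn(self):
    #             fetches = super().extra_action_out_fn()
    #             fetches[SampleBatch.VF_PREDS] = self.model.value_function()
    #             return fetches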

    @DeveloperAPI
    def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
        """Extra dict to pass to the compute gradients session run.

        Returns:
            Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
                compute_gradients Session.run() call.
        """
        return {}  # e.g., kl_coeff

    @DeveloperAPI
    def extra_compute_grad_fetches(self) -> Dict[str, any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Dict[str, any]: Extra fetch dict to be added to the fetch dict
                of the compute_gradients Session.run() call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.

    @DeveloperAPI
    def optimizer(self) -> "tf.keras.optimizers.Optimizer":
        """TF optimizer to use for policy optimization.

        Returns:
            tf.keras.optimizers.Optimizer: The local optimizer to use for this
                Policy's Model.
        """
        if hasattr(self, "config") and "lr" in self.config:
            return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
        else:
            return tf1.train.AdamOptimizer()
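
    # Illustrative override sketch (hypothetical): any tf1 optimizer can be
    # swapped in here, e.g.:
    #
    #     def optimizer(self):
    #         return tf1.train.RMSPropOptimizer(learning_rate=self.config["lr"])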

    @DeveloperAPI
    def gradients(
        self,
        optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
        loss: Union[TensorType, List[TensorType]],
    ) -> Union[List[ModelGradients], List[List[ModelGradients]]]:
        """Override this for a custom gradient computation behavior.

        Args:
            optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): A single
                LocalOptimizer or a list thereof to use for gradient
                calculations. If more than one optimizer given, the number of
                optimizers must match the number of losses provided.
            loss (Union[TensorType, List[TensorType]]): A single loss term
                or a list thereof to use for gradient calculations.
                If more than one loss given, the number of loss terms must
                match the number of optimizers provided.

        Returns:
            Union[List[ModelGradients], List[List[ModelGradients]]]: List of
                ModelGradients (grads and vars OR just grads) OR List of List
                of ModelGradients in case we have more than one
                optimizer/loss.
        """
        optimizers = force_list(optimizer)
        losses = force_list(loss)

        # We have more than one optimizer and loss term.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads = []
            for optim, loss_ in zip(optimizers, losses):
                grads.append(optim.compute_gradients(loss_))
            return grads
        # We have only one optimizer and one loss term.
        else:
            return optimizers[0].compute_gradients(losses[0])
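
    # Multi-loss sketch (assumes `_tf_policy_handles_more_than_one_loss=True`
    # in the config and that `optimizer()` and the loss function each return
    # lists of matching length): `gradients()` then yields one grads-and-vars
    # list per optimizer/loss pair, which `build_apply_op()` groups into a
    # single op:
    #
    #     grads_per_loss = self.gradients(self._optimizers, self._losses)
    #     apply_op = self.build_apply_op(self._optimizers, grads_per_loss)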

    @DeveloperAPI
    def build_apply_op(
        self,
        optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
        grads_and_vars: Union[ModelGradients, List[ModelGradients]],
    ) -> "tf.Operation":
        """Override this for a custom gradient apply computation behavior.

        Args:
            optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): The local
                tf optimizer to use for applying the grads and vars.
            grads_and_vars (Union[ModelGradients, List[ModelGradients]]): List
                of tuples with grad values and the grad-value's corresponding
                tf.variable in it.

        Returns:
            tf.Operation: The tf op that applies all computed gradients
                (`grads_and_vars`) to the model(s) via the given optimizer(s).
        """
        optimizers = force_list(optimizer)

        # We have more than one optimizer and loss term.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            ops = []
            for i, optim in enumerate(optimizers):
                # Specify global_step (e.g. for TD3 which needs to count the
                # num updates that have happened).
                ops.append(
                    optim.apply_gradients(
                        grads_and_vars[i],
                        global_step=tf1.train.get_or_create_global_step(),
                    )
                )
            return tf.group(ops)
        # We have only one optimizer and one loss term.
        else:
            return optimizers[0].apply_gradients(
                grads_and_vars, global_step=tf1.train.get_or_create_global_step()
            )

    def _get_is_training_placeholder(self):
        """Get the placeholder for _is_training, i.e., for batch norm layers.

        This can be called safely before __init__ has run.
        """
        if not hasattr(self, "_is_training"):
            self._is_training = tf1.placeholder_with_default(
                False, (), name="is_training"
            )
        return self._is_training

    def _debug_vars(self):
        if log_once("grad_vars"):
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                for group in self._grads_and_vars:
                    for _, v in group:
                        logger.info("Optimizing variable {}".format(v))
            else:
                for _, v in self._grads_and_vars:
                    logger.info("Optimizing variable {}".format(v))

    def _extra_input_signature_def(self):
        """Extra input signatures to add when exporting tf model.
        Inferred from extra_compute_action_feed_dict()
        """
        feed_dict = self.extra_compute_action_feed_dict()
        return {
            k.name: tf1.saved_model.utils.build_tensor_info(k) for k in feed_dict.keys()
        }

    def _extra_output_signature_def(self):
        """Extra output signatures to add when exporting tf model.
        Inferred from extra_compute_action_fetches()
        """
        fetches = self.extra_compute_action_fetches()
        return {
            k: tf1.saved_model.utils.build_tensor_info(fetches[k])
            for k in fetches.keys()
        }

    def _build_signature_def(self):
        """Build signature def map for tensorflow SavedModelBuilder."""
        # build input signatures
        input_signature = self._extra_input_signature_def()
        input_signature["observations"] = tf1.saved_model.utils.build_tensor_info(
            self._obs_input
        )

        if self._seq_lens is not None:
            input_signature[
                SampleBatch.SEQ_LENS
            ] = tf1.saved_model.utils.build_tensor_info(self._seq_lens)
        if self._prev_action_input is not None:
            input_signature["prev_action"] = tf1.saved_model.utils.build_tensor_info(
                self._prev_action_input
            )
        if self._prev_reward_input is not None:
            input_signature["prev_reward"] = tf1.saved_model.utils.build_tensor_info(
                self._prev_reward_input
            )

        input_signature["is_training"] = tf1.saved_model.utils.build_tensor_info(
            self._is_training
        )

        if self._timestep is not None:
            input_signature["timestep"] = tf1.saved_model.utils.build_tensor_info(
                self._timestep
            )

        for state_input in self._state_inputs:
            input_signature[state_input.name] = tf1.saved_model.utils.build_tensor_info(
                state_input
            )

        # build output signatures
        output_signature = self._extra_output_signature_def()
        for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
            output_signature[
                "actions_{}".format(i)
            ] = tf1.saved_model.utils.build_tensor_info(a)

        for state_output in self._state_outputs:
            output_signature[
                state_output.name
            ] = tf1.saved_model.utils.build_tensor_info(state_output)
        signature_def = tf1.saved_model.signature_def_utils.build_signature_def(
            input_signature,
            output_signature,
            tf1.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )
        signature_def_key = (
            tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )
        signature_def_map = {signature_def_key: signature_def}
        return signature_def_map

    def _build_compute_actions(
        self,
        builder,
        *,
        input_dict=None,
        obs_batch=None,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
    ):
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Call the exploration before_compute_actions hook.
        self.exploration.before_compute_actions(
            timestep=timestep, explore=explore, tf_sess=self.get_session()
        )

        builder.add_feed_dict(self.extra_compute_action_feed_dict())

        # `input_dict` given: Simply build what's in that dict.
        if input_dict is not None:
            if hasattr(self, "_input_dict"):
                for key, value in input_dict.items():
                    if key in self._input_dict:
                        # Handle complex/nested spaces as well.
                        tree.map_structure(
                            lambda k, v: builder.add_feed_dict({k: v}),
                            self._input_dict[key],
                            value,
                        )
            # For policies that inherit directly from TFPolicy.
            else:
                builder.add_feed_dict({self._obs_input: input_dict[SampleBatch.OBS]})
                if SampleBatch.PREV_ACTIONS in input_dict:
                    builder.add_feed_dict(
                        {self._prev_action_input: input_dict[SampleBatch.PREV_ACTIONS]}
                    )
                if SampleBatch.PREV_REWARDS in input_dict:
                    builder.add_feed_dict(
                        {self._prev_reward_input: input_dict[SampleBatch.PREV_REWARDS]}
                    )
                state_batches = []
                i = 0
                while "state_in_{}".format(i) in input_dict:
                    state_batches.append(input_dict["state_in_{}".format(i)])
                    i += 1
                builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))

            if "state_in_0" in input_dict and SampleBatch.SEQ_LENS not in input_dict:
                builder.add_feed_dict(
                    {self._seq_lens: np.ones(len(input_dict["state_in_0"]))}
                )

        # Hardcoded old way: Build fixed fields, if provided.
        # TODO: (sven) This can be deprecated after trajectory view API flag is
        # removed and always True.
        else:
            if log_once("_build_compute_actions_input_dict"):
                deprecation_warning(
                    old="_build_compute_actions(.., obs_batch=.., ..)",
                    new="_build_compute_actions(.., input_dict=..)",
                    error=False,
                )
            state_batches = state_batches or []
            if len(self._state_inputs) != len(state_batches):
                raise ValueError(
                    "Must pass in RNN state batches for placeholders {}, "
                    "got {}".format(self._state_inputs, state_batches)
                )

            tree.map_structure(
                lambda k, v: builder.add_feed_dict({k: v}),
                self._obs_input,
                obs_batch,
            )
            if state_batches:
                builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
            if self._prev_action_input is not None and prev_action_batch is not None:
                builder.add_feed_dict({self._prev_action_input: prev_action_batch})
            if self._prev_reward_input is not None and prev_reward_batch is not None:
                builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
            builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))

        builder.add_feed_dict({self._is_training: False})
        builder.add_feed_dict({self._is_exploring: explore})
        if timestep is not None:
            builder.add_feed_dict({self._timestep: timestep})

        # Determine, what exactly to fetch from the graph.
        to_fetch = (
            [self._sampled_action]
            + self._state_outputs
            + [self.extra_compute_action_fetches()]
        )

        # Perform the session call.
        fetches = builder.add_fetches(to_fetch)
        return fetches[0], fetches[1:-1], fetches[-1]

    def _build_compute_gradients(self, builder, postprocessed_batch):
        self._debug_vars()
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(
            self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
        )
        fetches = builder.add_fetches([self._grads, self._get_grad_and_stats_fetches()])
        return fetches[0], fetches[1]

    def _build_apply_gradients(self, builder, gradients):
        if len(gradients) != len(self._grads):
            raise ValueError(
                "Unexpected number of gradients to apply, got {} for {}".format(
                    gradients, self._grads
                )
            )
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(dict(zip(self._grads, gradients)))
        fetches = builder.add_fetches([self._apply_op])
        return fetches[0]

    def _build_learn_on_batch(self, builder, postprocessed_batch):
        self._debug_vars()

        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(
            self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
        )
        fetches = builder.add_fetches(
            [
                self._apply_op,
                self._get_grad_and_stats_fetches(),
            ]
        )
        return fetches[1]

    def _get_grad_and_stats_fetches(self):
        fetches = self.extra_compute_grad_fetches()
        if LEARNER_STATS_KEY not in fetches:
            raise ValueError("Grad fetches should contain 'stats': {...} entry")
        if self._stats_fetches:
            fetches[LEARNER_STATS_KEY] = dict(
                self._stats_fetches, **fetches[LEARNER_STATS_KEY]
            )
        return fetches

    def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool):
        """Return a feed dict from a batch.

        Args:
            train_batch (SampleBatch): batch of data to derive inputs from.
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            Feed dict of data.
        """

        # Get batch ready for RNNs, if applicable.
        if not isinstance(train_batch, SampleBatch) or not train_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                train_batch,
                max_seq_len=self._max_seq_len,
                shuffle=shuffle,
                batch_divisibility_req=self._batch_divisibility_req,
                feature_keys=list(self._loss_input_dict_no_rnn.keys()),
                view_requirements=self.view_requirements,
            )

        # Mark the batch as "is_training" so the Model can use this
        # information.
        train_batch.set_training(True)

        # Build the feed dict from the batch.
        feed_dict = {}
        for key, placeholders in self._loss_input_dict.items():
            a = tree.map_structure(
                lambda ph, v: feed_dict.__setitem__(ph, v),
                placeholders,
                train_batch[key],
            )
            del a

        state_keys = ["state_in_{}".format(i) for i in range(len(self._state_inputs))]
        for key in state_keys:
            feed_dict[self._loss_input_dict[key]] = train_batch[key]
        if state_keys:
            feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS]

        return feed_dict
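

# End-to-end usage sketch (illustrative only; in practice TFPolicy subclasses
# are generated via `build_tf_policy` / DynamicTFPolicy rather than written by
# hand, and `MyTFPolicySubclass` below is hypothetical):
#
#     policy = MyTFPolicySubclass(obs_space, action_space, config, ...)
#     actions, _, _ = policy.compute_actions(obs_batch)
#     batch = policy.postprocess_trajectory(SampleBatch({...}))
#     stats = policy.learn_on_batch(batch)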