"""Graph mode TF policy built using build_tf_policy()."""

from collections import OrderedDict
import logging

import numpy as np

from ray.util.debug import log_once
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tracking_dict import UsageTrackingDict

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

class DynamicTFPolicy(TFPolicy):
    """A TFPolicy that auto-defines placeholders dynamically at runtime.

    Initialization of this class occurs in two phases.

    * Phase 1: The model is created and model variables are initialized.
    * Phase 2: A fake batch of data is created, sent to the trajectory
      postprocessor, and then used to create placeholders for the loss
      function. The loss and stats functions are initialized with these
      placeholders.

    Initialization defines the static graph.

    Attributes:
        observation_space (gym.Space): Observation space of the policy.
        action_space (gym.Space): Action space of the policy.
        config (dict): Config dict of the policy.
        model (ModelV2): The TF ModelV2 instance used by this policy.
        dist_class (type): The TF action distribution class.
    """
    def __init__(self,
                 obs_space,
                 action_space,
                 config,
                 loss_fn,
                 stats_fn=None,
                 grad_stats_fn=None,
                 before_loss_init=None,
                 make_model=None,
                 action_sampler_fn=None,
                 action_distribution_fn=None,
                 existing_inputs=None,
                 existing_model=None,
                 get_batch_divisibility_req=None,
                 obs_include_prev_action_reward=True):
        """Initialize a dynamic TF policy.

        Arguments:
            obs_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data.
            loss_fn (func): Function that returns a loss tensor given the
                policy, the model, the action distribution class, and a train
                batch (dict of experience tensor placeholders).
            stats_fn (func): Optional function that returns a dict of
                TF fetches given the policy and batch input tensors.
            grad_stats_fn (func): Optional function that returns a dict of
                TF fetches given the policy, batch input tensors, and loss
                gradient tensors.
            before_loss_init (Optional[callable]): Optional function to run
                prior to loss init that takes the same arguments as __init__.
            make_model (func): Optional function that returns a ModelV2 object
                given (policy, obs_space, action_space, config). All policy
                variables should be created in this function. If not
                specified, a default model will be created.
            action_sampler_fn (Optional[callable]): An optional callable
                returning a tuple of sampled action and action log-prob
                tensors given the policy, the model, and batch input tensors
                (obs_batch, state_batches, seq_lens, prev_action_batch,
                prev_reward_batch, explore, is_training). If None, actions
                are sampled from an action distribution instead.
            action_distribution_fn (Optional[callable]): A callable returning
                distribution inputs (parameters), a dist-class to generate an
                action distribution object from, and internal-state outputs
                (or an empty list if not applicable). Note: No Exploration
                hooks have to be called from within `action_distribution_fn`.
                It should only perform a simple forward pass through some
                model. If None, pass inputs through `self.model()` to get the
                distribution inputs.
            existing_inputs (OrderedDict): When copying a policy, this
                specifies an existing dict of placeholders to use instead of
                defining new ones.
            existing_model (ModelV2): When copying a policy, this specifies
                an existing model to clone and share weights with.
            get_batch_divisibility_req (func): Optional function that returns
                the divisibility requirement for sample batches.
            obs_include_prev_action_reward (bool): Whether to include the
                previous action and reward in the model input.
        """
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._obs_include_prev_action_reward = obs_include_prev_action_reward

        # Setup standard placeholders.
        prev_actions = None
        prev_rewards = None
        if existing_inputs is not None:
            obs = existing_inputs[SampleBatch.CUR_OBS]
            if self._obs_include_prev_action_reward:
                prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
                prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
            action_input = existing_inputs[SampleBatch.ACTIONS]
            explore = existing_inputs["is_exploring"]
            timestep = existing_inputs["timestep"]
        else:
            obs = tf1.placeholder(
                tf.float32,
                shape=[None] + list(obs_space.shape),
                name="observation")
            action_input = ModelCatalog.get_action_placeholder(action_space)
            if self._obs_include_prev_action_reward:
                prev_actions = ModelCatalog.get_action_placeholder(
                    action_space, "prev_action")
                prev_rewards = tf1.placeholder(
                    tf.float32, [None], name="prev_reward")
            explore = tf1.placeholder_with_default(
                True, (), name="is_exploring")
            timestep = tf1.placeholder(tf.int32, (), name="timestep")

        self._input_dict = {
            SampleBatch.CUR_OBS: obs,
            SampleBatch.PREV_ACTIONS: prev_actions,
            SampleBatch.PREV_REWARDS: prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        }
        # Placeholder for RNN time-chunk valid lengths.
        self._seq_lens = tf1.placeholder(
            dtype=tf.int32, shape=[None], name="seq_lens")

        dist_class = dist_inputs = None
        if action_sampler_fn or action_distribution_fn:
            if not make_model:
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given")
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        # Setup self.model.
        if existing_model:
            self.model = existing_model
        elif make_model:
            self.model = make_model(self, obs_space, action_space, config)
        else:
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="tf",
                **self.config["model"].get("custom_model_config", {}))

        # Create the Exploration object to use for this Policy.
        self.exploration = self._create_exploration()

        if existing_inputs:
            self._state_in = [
                v for k, v in existing_inputs.items()
                if k.startswith("state_in_")
            ]
            if self._state_in:
                self._seq_lens = existing_inputs["seq_lens"]
        else:
            self._state_in = [
                tf1.placeholder(shape=(None, ) + s.shape, dtype=s.dtype)
                for s in self.model.get_initial_state()
            ]
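
        # A rough sketch of the two pluggable action paths handled below
        # (illustrative only; these are not exact RLlib signatures):
        #
        #     def action_sampler_fn(policy, model, obs_batch, **kwargs):
        #         ...  # -> (sampled_action, sampled_action_logp)
        #
        #     def action_distribution_fn(policy, model, obs_batch, **kwargs):
        #         ...  # -> (dist_inputs, dist_class, state_out_list)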

        # Fully customized action generation (e.g., custom policy).
        if action_sampler_fn:
            sampled_action, sampled_action_logp = action_sampler_fn(
                self,
                self.model,
                obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                state_batches=self._state_in,
                seq_lens=self._seq_lens,
                prev_action_batch=self._input_dict[SampleBatch.PREV_ACTIONS],
                prev_reward_batch=self._input_dict[SampleBatch.PREV_REWARDS],
                explore=explore,
                is_training=self._input_dict["is_training"])
        else:
            # Distribution generation is customized, e.g., DQN, DDPG.
            if action_distribution_fn:
                dist_inputs, dist_class, self._state_out = \
                    action_distribution_fn(
                        self, self.model,
                        obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                        state_batches=self._state_in,
                        seq_lens=self._seq_lens,
                        prev_action_batch=self._input_dict[
                            SampleBatch.PREV_ACTIONS],
                        prev_reward_batch=self._input_dict[
                            SampleBatch.PREV_REWARDS],
                        explore=explore,
                        is_training=self._input_dict["is_training"])
            # Default distribution generation behavior:
            # Pass through model. E.g., PG, PPO.
            else:
                dist_inputs, self._state_out = self.model(
                    self._input_dict, self._state_in, self._seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Using exploration to get final action (e.g. via sampling).
            sampled_action, sampled_action_logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session()
        if get_batch_divisibility_req:
            batch_divisibility_req = get_batch_divisibility_req(self)
        else:
            batch_divisibility_req = 1

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=obs,
            action_input=action_input,  # for logp calculations
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_in,
            state_outputs=self._state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep)

        # Phase 2 init.
        if before_loss_init is not None:
            before_loss_init(self, obs_space, action_space, config)

        if not existing_inputs:
            self._initialize_loss()

    @override(TFPolicy)
    def copy(self, existing_inputs):
        """Creates a copy of self using existing input placeholders."""

        # Note that there might be RNN state inputs at the end of the list.
        if self._state_inputs:
            num_state_inputs = len(self._state_inputs) + 1
        else:
            num_state_inputs = 0
        if len(self._loss_inputs) + num_state_inputs != len(existing_inputs):
            raise ValueError("Tensor list mismatch", self._loss_inputs,
                             self._state_inputs, existing_inputs)
        for i, (k, v) in enumerate(self._loss_inputs):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError("Tensor shape mismatch", i, k, v.shape,
                                 existing_inputs[i].shape)
        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor.
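        # For example, with two RNN state tensors, `existing_inputs` is
        # expected to look roughly like:
        #     [<loss input 0>, ..., <loss input N>,
        #      state_in_0, state_in_1, seq_lens]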
        rnn_inputs = []
        for i in range(len(self._state_inputs)):
            rnn_inputs.append(("state_in_{}".format(i),
                               existing_inputs[len(self._loss_inputs) + i]))
        if rnn_inputs:
            rnn_inputs.append(("seq_lens", existing_inputs[-1]))
        input_dict = OrderedDict(
            [("is_exploring", self._is_exploring),
             ("timestep", self._timestep)] +
            [(k, existing_inputs[i])
             for i, (k, _) in enumerate(self._loss_inputs)] + rnn_inputs)
        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict,
            existing_model=self.model)

        instance._loss_input_dict = input_dict
        loss = instance._do_loss_init(input_dict)
        loss_inputs = [(k, existing_inputs[i])
                       for i, (k, _) in enumerate(self._loss_inputs)]

        TFPolicy._initialize_loss(instance, loss, loss_inputs)
        if instance._grad_stats_fn:
            instance._stats_fetches.update(
                instance._grad_stats_fn(instance, input_dict,
                                        instance._grads))
        return instance

    @override(Policy)
    def get_initial_state(self):
        if self.model:
            return self.model.get_initial_state()
        else:
            return []

    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
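
        # `fake_array()` builds a zero-filled numpy array matching a
        # placeholder's shape and dtype (unknown dimensions become 1). The
        # dummy batch assembled from such arrays below is what drives Phase 2
        # of initialization.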
        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # Postprocessing might depend on variable init, so run it first here.
        self._sess.run(tf1.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))
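
        # The postprocessor may add brand-new columns to this batch (e.g. an
        # "advantages" column computed from the dummy rewards); the loop
        # further below creates a matching placeholder for each such column
        # so it can later be fed as a loss input.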

        # Model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call).
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict({
                SampleBatch.PREV_ACTIONS: self._prev_action_input,
                SampleBatch.PREV_REWARDS: self._prev_reward_input,
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({
                SampleBatch.CUR_OBS: self._obs_input,
            })
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf1.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens
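
        # `train_batch` is a UsageTrackingDict: every key that the loss and
        # stats functions read below is recorded in `accessed_keys`, and only
        # those keys are promoted to additional loss inputs afterwards.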

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)))

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads))
        self._sess.run(tf1.global_variables_initializer())

    def _do_loss_init(self, train_batch):
        loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
        if self._stats_fn:
            self._stats_fetches.update(self._stats_fn(self, train_batch))
        # Override the update ops to be those of the model.
        self._update_ops = self.model.update_ops()
        return loss