from collections import OrderedDict
import gym
import logging
import re
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union

from ray.util.debug import log_once
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
    DeveloperAPI,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    is_overridden,
    override,
)
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
from ray.rllib.utils.tf_utils import get_placeholder
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
    TrainerConfigDict,
)

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


@DeveloperAPI
class DynamicTFPolicyV2(TFPolicy):
    """A TFPolicy that auto-defines placeholders dynamically at runtime.

    This class is intended to be used and extended by sub-classing.
    """

    @DeveloperAPI
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
        existing_model: Optional[ModelV2] = None,
    ):
        """Initializes a DynamicTFPolicyV2 instance.

        Args:
            obs_space: The Policy's observation space.
            action_space: The Policy's action space.
            config: The Policy's config dict.
            existing_inputs: When copying this policy (e.g. to build a
                multi-GPU tower), the already existing input placeholders
                to re-use.
            existing_model: When copying this policy, the already existing
                model to re-use (optionally followed by (attr_name, model)
                tuples, e.g. for target models).
        """
        self.observation_space = obs_space
        self.action_space = action_space
        config = dict(self.get_default_config(), **config)
        self.config = config
        self.framework = "tf"
        self._seq_lens = None
        self._is_tower = existing_inputs is not None

        self.validate_spaces(obs_space, action_space, config)

        self.dist_class = self._init_dist_class()
        # Setup self.model.
        if existing_model and isinstance(existing_model, list):
            self.model = existing_model[0]
            # TODO: (sven) hack, but works for `target_[q_]?model`.
            for i in range(1, len(existing_model)):
                setattr(self, existing_model[i][0], existing_model[i][1])
        else:
            self.model = self.make_model()
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        self._init_state_inputs(existing_inputs)
        self._init_view_requirements()
        timestep, explore = self._init_input_dict_and_dummy_batch(existing_inputs)
        (
            sampled_action,
            sampled_action_logp,
            dist_inputs,
            self._policy_extra_action_fetches,
        ) = self._init_action_fetches(timestep, explore)

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session(
            config=tf1.ConfigProto(**self.config["tf_session_args"])
        )

        batch_divisibility_req = self.get_batch_divisibility_req()

        prev_action_input = (
            self._input_dict[SampleBatch.PREV_ACTIONS]
            if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys
            else None
        )
        prev_reward_input = (
            self._input_dict[SampleBatch.PREV_REWARDS]
            if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys
            else None
        )

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=self.dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=prev_action_input,
            prev_reward_input=prev_reward_input,
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep,
        )

    @DeveloperAPI
    @staticmethod
    def enable_eager_execution_if_necessary():
        # This is a static-graph TF policy: simply do nothing.
        pass

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def get_default_config(self) -> TrainerConfigDict:
        return {}

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def validate_spaces(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
    ):
        return {}

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    @override(Policy)
    def loss(
        self,
        model: Union[ModelV2, "tf.keras.Model"],
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss computation graph for this TF1 policy.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distribution class.
            train_batch: The training data.

        Returns:
            A single loss tensor or a list of loss tensors.
        """
        raise NotImplementedError

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Stats function. Returns a dict of statistics.

        Args:
            train_batch: The SampleBatch (already) used for training.

        Returns:
            The stats dict.
        """
        return {}
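
    # Illustrative sketch (not part of this class): a subclass would typically
    # return tensors that its own `loss()` implementation created, e.g.:
    #
    #     def stats_fn(self, train_batch):
    #         # `self._mean_policy_loss` is a hypothetical attribute stored by
    #         # the subclass's `loss()` implementation.
    #         return {"policy_loss": self._mean_policy_loss}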

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def grad_stats_fn(
        self, train_batch: SampleBatch, grads: ModelGradients
    ) -> Dict[str, TensorType]:
        """Gradient stats function. Returns a dict of statistics.

        Args:
            train_batch: The SampleBatch (already) used for training.
            grads: The gradients computed for `train_batch`.

        Returns:
            The stats dict.
        """
        return {}

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def make_model(self) -> ModelV2:
        """Builds the underlying model for this Policy.

        Returns:
            The Model for the Policy to use.
        """
        # Default ModelV2 model.
        _, logit_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"]
        )
        return ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="tf",
        )
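
    # Illustrative sketch (hypothetical names): an algorithm that needs a
    # custom number of model outputs, e.g. one Q-value per discrete action
    # instead of action-distribution inputs, could override `make_model()`:
    #
    #     def make_model(self) -> ModelV2:
    #         return ModelCatalog.get_model_v2(
    #             obs_space=self.observation_space,
    #             action_space=self.action_space,
    #             num_outputs=self.action_space.n,  # one output per action
    #             model_config=self.config["model"],
    #             framework="tf",
    #             name="q_model",
    #         )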

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def compute_gradients_fn(
        self, optimizer: LocalOptimizer, loss: TensorType
    ) -> ModelGradients:
        """Gradients computing function (from loss tensor, using local optimizer).

        Args:
            optimizer: The tf (local) optimizer object to calculate the
                gradients with.
            loss: The loss tensor for which gradients should be calculated.

        Returns:
            List of the (possibly clipped) gradient- and variable tuples.
        """
        return None

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def apply_gradients_fn(
        self,
        policy: Policy,
        optimizer: "tf.keras.optimizers.Optimizer",
        grads: ModelGradients,
    ) -> "tf.Operation":
        """Gradients applying function (from computed gradients, using local optimizer).

        Args:
            policy: The Policy object that generated the loss tensor and
                that holds the given local optimizer.
            optimizer: The tf (local) optimizer object to apply the
                gradients with.
            grads: The gradient tensors to be applied.

        Returns:
            The TF operation that applies the supplied gradients.
        """
        return None

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def action_sampler_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        **kwargs,
    ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
        """Custom function for sampling new actions given the policy.

        Args:
            model: Underlying model.
            obs_batch: Observation tensor batch.
            state_batches: Action sampling state batch.

        Returns:
            Sampled action.
            Log-likelihood of the sampled action.
            Action distribution inputs.
            Updated state.
        """
        return None, None, None, None

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def action_distribution_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        **kwargs,
    ) -> Tuple[TensorType, type, List[TensorType]]:
        """Action distribution function for this Policy.

        Args:
            model: Underlying model.
            obs_batch: Observation tensor batch.
            state_batches: Action sampling state batch.

        Returns:
            Distribution inputs.
            ActionDistribution class to use.
            State outs.
        """
        return None, None, None
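
    # Illustrative sketch (hypothetical): a Q-learning style policy could map
    # observations to per-action Q-values and declare a Categorical action
    # distribution over them. Note that this class calls the function with an
    # `input_dict` keyword argument (see `_init_action_fetches()` below), so
    # extra kwargs should be absorbed via `**kwargs`:
    #
    #     from ray.rllib.models.tf.tf_action_dist import Categorical
    #
    #     def action_distribution_fn(self, model, *, input_dict=None, **kwargs):
    #         q_values, state_out = model(input_dict)
    #         return q_values, Categorical, state_out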

    @DeveloperAPI
    @OverrideToImplementCustomLogic
    def get_batch_divisibility_req(self) -> int:
        """Get batch divisibility requirement.

        Returns:
            Size N. A sample batch must be of size K*N.
        """
        # By default, any sized batch is ok, so simply return 1.
        return 1
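
    # Illustrative sketch: an algorithm whose loss operates on complete
    # rollout fragments (a hypothetical requirement) could force batch sizes
    # to be a multiple of the configured fragment length:
    #
    #     def get_batch_divisibility_req(self) -> int:
    #         return self.config["rollout_fragment_length"]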

    @override(TFPolicy)
    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        """Extra values to fetch and return from compute_actions().

        Returns:
            An extra fetch-dict to be passed to and returned from the
            compute_actions() call.
        """
        extra_action_fetches = super().extra_action_out_fn()
        extra_action_fetches.update(self._policy_extra_action_fetches)
        return extra_action_fetches
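
    # Illustrative sketch: policies with a value function commonly fetch the
    # model's value estimates alongside the sampled actions (a typical
    # pattern; the exact keys depend on the algorithm):
    #
    #     def extra_action_out_fn(self):
    #         fetches = super().extra_action_out_fn()
    #         fetches[SampleBatch.VF_PREDS] = self.model.value_function()
    #         return fetches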

    @DeveloperAPI
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
        """Extra stats to be reported after gradient computation.

        Returns:
            An extra fetch-dict.
        """
        return {}

    @override(TFPolicy)
    def extra_compute_grad_fetches(self):
        return dict({LEARNER_STATS_KEY: {}}, **self.extra_learn_fetches_fn())

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[SampleBatch] = None,
        episode: Optional["Episode"] = None,
    ):
        """Post processes a trajectory given as a SampleBatch.

        Args:
            sample_batch: The batch of experiences for the policy, which will
                contain at most one episode trajectory.
            other_agent_batches: In a multi-agent env, this contains a
                mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agents.
            episode: An optional multi-agent episode object to provide
                access to all of the internal episode state, which may
                be useful for model-based or multi-agent algorithms.

        Returns:
            The postprocessed sample batch.
        """
        return Policy.postprocess_trajectory(self, sample_batch)
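
    # Illustrative sketch (not this class's default behavior): an actor-critic
    # style subclass would typically compute advantages here, e.g. via RLlib's
    # `compute_advantages()` helper; `last_r=0.0` is a simplifying assumption
    # (normally the bootstrap value at the truncation point):
    #
    #     from ray.rllib.evaluation.postprocessing import compute_advantages
    #
    #     def postprocess_trajectory(
    #         self, sample_batch, other_agent_batches=None, episode=None
    #     ):
    #         sample_batch = super().postprocess_trajectory(
    #             sample_batch, other_agent_batches, episode
    #         )
    #         return compute_advantages(
    #             sample_batch,
    #             last_r=0.0,
    #             gamma=self.config["gamma"],
    #             use_gae=False,
    #             use_critic=False,
    #         )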

    @override(TFPolicy)
    @OverrideToImplementCustomLogic
    def optimizer(
        self,
    ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
        """TF optimizer to use for policy optimization.

        Returns:
            A local optimizer or a list of local optimizers to use for this
            Policy's Model.
        """
        return super().optimizer()

    def _init_dist_class(self):
        if is_overridden(self.action_sampler_fn) or is_overridden(
            self.action_distribution_fn
        ):
            if not is_overridden(self.make_model):
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given"
                )
        else:
            dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"]
            )
            return dist_class

    def _init_view_requirements(self):
        # If ViewRequirements are explicitly specified.
        if getattr(self, "view_requirements", None):
            return

        # Use default settings.
        # Add NEXT_OBS, STATE_IN_0.., and others.
        self.view_requirements = self._get_default_view_requirements()
        # Combine view_requirements for Model and Policy.
        # TODO(jungong) : models will not carry view_requirements once they
        # are migrated to be organic Keras models.
        self.view_requirements.update(self.model.view_requirements)
        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_training = False

    def _init_state_inputs(self, existing_inputs: Dict[str, "tf1.placeholder"]):
        """Initializes RNN-state input placeholders.

        Args:
            existing_inputs: Already existing input placeholders (used when
                copying this Policy).
        """
        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items() if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
        # Create new input placeholders.
        else:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                    name=k,
                )
                for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = tf1.placeholder(
                    dtype=tf.int32, shape=[None], name="seq_lens"
                )

    def _init_input_dict_and_dummy_batch(
        self, existing_inputs: Dict[str, "tf1.placeholder"]
    ) -> Tuple[Union[int, TensorType], Union[bool, TensorType]]:
        """Initializes the input_dict and dummy_batch data.

        Args:
            existing_inputs: When copying a policy, this specifies an existing
                dict of placeholders to use instead of defining new ones.

        Returns:
            timestep: The training timestep.
            explore: Whether this policy should explore.
        """
        # Setup standard placeholders.
        if self._is_tower:
            assert existing_inputs is not None
            timestep = existing_inputs["timestep"]
            explore = False
            (
                self._input_dict,
                self._dummy_batch,
            ) = self._create_input_dict_and_dummy_batch(
                self.view_requirements, existing_inputs
            )
        else:
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(True, (), name="is_exploring")
            (
                self._input_dict,
                self._dummy_batch,
            ) = self._create_input_dict_and_dummy_batch(self.view_requirements, {})

        # Placeholder for `is_training` flag.
        self._input_dict.set_training(self._get_is_training_placeholder())

        return timestep, explore

    def _create_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        input_dict maps str keys to tf placeholders; dummy_batch maps str
        keys to np.ndarrays.

        Args:
            view_requirements (ViewReqs): The view requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            # This is a tower: Input placeholders already exist.
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others.
            else:
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    # Do not flatten actions if action flattening disabled.
                    if self.config.get("_disable_action_flattening") and view_col in [
                        SampleBatch.ACTIONS,
                        SampleBatch.PREV_ACTIONS,
                    ]:
                        flatten = False
                    # Do not flatten observations if no preprocessor API used.
                    elif (
                        view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                        and self.config["_disable_preprocessor_api"]
                    ):
                        flatten = False
                    # Flatten everything else.
                    else:
                        flatten = True
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis,
                        flatten=flatten,
                    )
        dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch

    def _init_action_fetches(
        self, timestep: Union[int, TensorType], explore: Union[bool, TensorType]
    ) -> Tuple[TensorType, TensorType, TensorType, Dict[str, TensorType]]:
        """Creates action related fields for base Policy and loss initialization."""
        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        extra_action_fetches = {}
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if is_overridden(self.action_sampler_fn):
                (
                    sampled_action,
                    sampled_action_logp,
                    dist_inputs,
                    self._state_out,
                ) = self.action_sampler_fn(
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training,
                )
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if is_overridden(self.action_distribution_fn):
                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    (
                        dist_inputs,
                        self.dist_class,
                        self._state_out,
                    ) = self.action_distribution_fn(
                        self.model,
                        input_dict=in_dict,
                        state_batches=self._state_inputs,
                        seq_lens=self._seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=in_dict.is_training,
                    )
                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        dist_inputs, self._state_out, extra_action_fetches = self.model(
                            self._input_dict
                        )
                    else:
                        dist_inputs, self._state_out = self.model(self._input_dict)

                action_dist = self.dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                (
                    sampled_action,
                    sampled_action_logp,
                ) = self.exploration.get_exploration_action(
                    action_distribution=action_dist, timestep=timestep, explore=explore
                )

        if dist_inputs is not None:
            extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        if sampled_action_logp is not None:
            extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
            extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
                tf.cast(sampled_action_logp, tf.float32)
            )

        return (
            sampled_action,
            sampled_action_logp,
            dist_inputs,
            extra_action_fetches,
        )

    def _init_optimizers(self):
        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        optimizers = force_list(self.optimizer())
        if getattr(self, "exploration", None):
            optimizers = self.exploration.get_exploration_optimizer(optimizers)

        # No optimizers produced -> Return.
        if not optimizers:
            return

        # The list of local (tf) optimizers (one per loss term).
        self._optimizers = optimizers
        # Backward compatibility.
        self._optimizer = optimizers[0]

    def maybe_initialize_optimizer_and_loss(self):
        # We don't need to initialize loss calculation for MultiGPUTowerStack.
        if self._is_tower:
            return

        # Loss initialization and model/postprocessing test calls.
        self._init_optimizers()
        self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)

        # Create MultiGPUTowerStacks, if we have at least one actual
        # GPU or >1 CPUs (fake GPUs).
        if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
            # Per-GPU graph copies created here must share vars with the
            # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
            # Adam nodes are created after all of the device copies are
            # created.
            with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                self.multi_gpu_tower_stacks = [
                    TFMultiGPUTowerStack(policy=self)
                    for _ in range(self.config.get("num_multi_gpu_tower_stacks", 1))
                ]

        # Initialize again after loss and tower init.
        self.get_session().run(tf1.global_variables_initializer())

    @override(Policy)
    def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True
    ) -> None:
        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (
                not key.startswith("state_in_")
                and key not in self._input_dict.accessed_keys
            ):
                view_req.used_for_compute_actions = False
        for key, value in self.extra_action_out_fn().items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
                ),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype.name
                    ),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=dummy_batch[key], name=key
                )
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
            )

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)
                )
            )

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = (
            train_batch.accessed_keys
            | dummy_batch.accessed_keys
            | dummy_batch.added_keys
            | set(self.model.view_requirements.keys())
        )

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
            + (
                [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
                if SampleBatch.SEQ_LENS in train_batch
                else []
            ),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        self._stats_fetches.update(self.grad_stats_fn(train_batch, self._grads))

        # Automatically clean up the view-reqs: remove unneeded columns and
        # re-add missing data-col dependencies.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (
                    key not in train_batch.accessed_keys
                    and key not in self.model.view_requirements
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.OBS_EMBEDS,
                    ]
                ):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave DONES, REWARDS, and INFOS, no matter what.
            for key in list(self.view_requirements.keys()):
                if (
                    key not in all_accessed_keys
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                    ]
                    and key not in self.model.view_requirements
                ):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key)
                        )
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Re-add those data_cols that are missing but are depended upon
            # by other view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (
                    vr.data_col is not None
                    and vr.data_col not in self.view_requirements
                ):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training
                    )

        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }

    def _do_loss_init(self, train_batch: SampleBatch):
        losses = self.loss(self.model, self.dist_class, train_batch)
        losses = force_list(losses)
        self._stats_fetches.update(self.stats_fn(train_batch))
        # Override the update ops to be those of the model.
        self._update_ops = []
        if not isinstance(self.model, tf.keras.Model):
            self._update_ops = self.model.update_ops()
        return losses

    @override(TFPolicy)
    @DeveloperAPI
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
        """Creates a copy of self using existing input placeholders."""

        flat_loss_inputs = tree.flatten(self._loss_input_dict)
        flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)

        # Note that there might be RNN state inputs at the end of the list.
        if len(flat_loss_inputs) != len(existing_inputs):
            raise ValueError(
                "Tensor list mismatch",
                self._loss_input_dict,
                self._state_inputs,
                existing_inputs,
            )
        for i, v in enumerate(flat_loss_inputs_no_rnn):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError(
                    "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
                )
        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor.
        rnn_inputs = []
        for i in range(len(self._state_inputs)):
            rnn_inputs.append(
                (
                    "state_in_{}".format(i),
                    existing_inputs[len(flat_loss_inputs_no_rnn) + i],
                )
            )
        if rnn_inputs:
            rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
        existing_inputs_unflattened = tree.unflatten_as(
            self._loss_input_dict_no_rnn,
            existing_inputs[: len(flat_loss_inputs_no_rnn)],
        )
        input_dict = OrderedDict(
            [("is_exploring", self._is_exploring), ("timestep", self._timestep)]
            + [
                (k, existing_inputs_unflattened[k])
                for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
            ]
            + rnn_inputs
        )

        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict,
            existing_model=[
                self.model,
                # Deprecated: Target models should all reside under
                # `policy.target_model` now.
                ("target_q_model", getattr(self, "target_q_model", None)),
                ("target_model", getattr(self, "target_model", None)),
            ],
        )

        instance._loss_input_dict = input_dict
        losses = instance._do_loss_init(SampleBatch(input_dict))
        loss_inputs = [
            (k, existing_inputs_unflattened[k])
            for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
        ]

        TFPolicy._initialize_loss(instance, losses, loss_inputs)
        instance._stats_fetches.update(
            instance.grad_stats_fn(input_dict, instance._grads)
        )
        return instance

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        if self.model:
            return self.model.get_initial_state()
        else:
            return []

    @override(Policy)
    @DeveloperAPI
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            self._loaded_single_cpu_batch = batch
            return len(batch)

        input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
        data_keys = tree.flatten(self._loss_input_dict_no_rnn)
        if self._state_inputs:
            state_keys = self._state_inputs + [self._seq_lens]
        else:
            state_keys = []
        inputs = [input_dict[k] for k in data_keys]
        state_inputs = [input_dict[k] for k in state_keys]

        return self.multi_gpu_tower_stacks[buffer_index].load_data(
            sess=self.get_session(),
            inputs=inputs,
            state_inputs=state_inputs,
        )

    @override(Policy)
    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            return (
                len(self._loaded_single_cpu_batch)
                if self._loaded_single_cpu_batch is not None
                else 0
            )

        return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded

    @override(Policy)
    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            if self._loaded_single_cpu_batch is None:
                raise ValueError(
                    "Must call Policy.load_batch_into_buffer() before "
                    "Policy.learn_on_loaded_batch()!"
                )
            # Get the correct slice of the already loaded batch to use,
            # based on offset and batch size.
            batch_size = self.config.get(
                "sgd_minibatch_size", self.config["train_batch_size"]
            )
            if batch_size >= len(self._loaded_single_cpu_batch):
                sliced_batch = self._loaded_single_cpu_batch
            else:
                sliced_batch = self._loaded_single_cpu_batch.slice(
                    start=offset, end=offset + batch_size
                )
            return self.learn_on_batch(sliced_batch)

        return self.multi_gpu_tower_stacks[buffer_index].optimize(
            self.get_session(), offset
        )
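
    # Illustrative usage sketch of the pre-loading API above (hypothetical
    # objects; `policy` is an already built, loss-initialized instance and
    # `train_batch` a SampleBatch of training data):
    #
    #     policy.load_batch_into_buffer(train_batch, buffer_index=0)
    #     num_sgd_iters = 3  # assumed value, normally taken from the config
    #     for _ in range(num_sgd_iters):
    #         results = policy.learn_on_loaded_batch(offset=0, buffer_index=0)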

    @override(TFPolicy)
    def gradients(self, optimizer, loss):
        optimizers = force_list(optimizer)
        losses = force_list(loss)

        if is_overridden(self.compute_gradients_fn):
            # New API: Allow more than one optimizer -> Return a list of
            # lists of gradients.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                return self.compute_gradients_fn(optimizers, losses)
            # Old API: Return a single List of gradients.
            else:
                return self.compute_gradients_fn(optimizers[0], losses[0])
        else:
            return super().gradients(optimizers, losses)