from abc import ABCMeta, abstractmethod
from collections import namedtuple
import gym
from gym.spaces import Box
import logging
import numpy as np
from typing import Dict, List, Optional, TYPE_CHECKING

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import clip_action, \
    get_base_struct_from_space, unbatch, unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
    TensorType, TrainerConfigDict, Tuple, Union

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

if TYPE_CHECKING:
    from ray.rllib.evaluation import MultiAgentEpisode

logger = logging.getLogger(__name__)

# By convention, metrics from optimizing the loss can be reported in the
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"

# A policy spec, used as the values in the "config.multiagent.policies"
# specification dict (keys are the policy IDs (str)). E.g.:
# config:
#   multiagent:
#     policies: {
#       "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}),
#       "pol2": PolicySpec(config={"lr": 0.001}),
#     }
PolicySpec = namedtuple(
    "PolicySpec",
    [
        # If None, use the Trainer's default policy class stored under
        # `Trainer._policy_class`.
        "policy_class",
        # If None, use the env's observation space. If None and there is no
        # Env (e.g. offline RL), an error is thrown.
        "observation_space",
        # If None, use the env's action space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        "action_space",
        # Overrides defined keys in the main Trainer config.
        # If None, use {}.
        "config",
    ])  # defaults=(None, None, None, None)
# TODO: From 3.7 on, we could pass `defaults` into the above constructor.
#  We still support py3.6.
PolicySpec.__new__.__defaults__ = (None, None, None, None)
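
# Illustrative sketch only: thanks to the defaults set above, every field of a
# PolicySpec is optional, so an empty spec is valid, e.g.:
#   spec = PolicySpec()
#   assert spec.policy_class is None and spec.config is None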


@DeveloperAPI
class Policy(metaclass=ABCMeta):
    """An agent policy and loss, i.e., a TFPolicy or other subclass.

    This object defines how to act in the environment, and also losses used to
    improve the policy based on its experiences. Note that both policy and
    loss are defined together for convenience, though the policy itself is
    logically separate.

    All policies can directly extend Policy; however, TensorFlow users may
    find TFPolicy simpler to implement. TFPolicy also enables RLlib
    to apply TensorFlow-specific optimizations such as fusing multiple policy
    graphs and multi-GPU support.

    Attributes:
        observation_space (gym.Space): Observation space of the policy. For
            complex spaces (e.g., Dict), this will be a flattened version of
            the space, and you can access the original space via
            ``observation_space.original_space``.
        action_space (gym.Space): Action space of the policy.
        exploration (Exploration): The exploration object to use for
            computing actions, or None.
    """

    @DeveloperAPI
    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, config: TrainerConfigDict):
        """Initializes a Policy instance.

        This is the standard constructor for policies. The policy
        class you pass into RolloutWorker will be constructed with
        these arguments.

        Args:
            observation_space (gym.spaces.Space): Observation space of the
                policy.
            action_space (gym.spaces.Space): Action space of the policy.
            config (TrainerConfigDict): Policy-specific configuration data.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.action_space_struct = get_base_struct_from_space(action_space)
        self.config = config
        if self.config.get("callbacks"):
            self.callbacks: "DefaultCallbacks" = self.config.get("callbacks")()
        else:
            from ray.rllib.agents.callbacks import DefaultCallbacks
            self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
        # The global timestep, broadcast down from time to time from the
        # driver.
        self.global_timestep = 0
        # The action distribution class to use for action sampling, if any.
        # Child classes may set this.
        self.dist_class = None
        # Maximal view requirements dict for `learn_on_batch()` and
        # `compute_actions` calls.
        # View requirements will be automatically filtered out later based
        # on the postprocessing and loss functions to ensure optimal data
        # collection and transfer performance.
        view_reqs = self._get_default_view_requirements()
        if not hasattr(self, "view_requirements"):
            self.view_requirements = view_reqs
        else:
            for k, v in view_reqs.items():
                if k not in self.view_requirements:
                    self.view_requirements[k] = v
        self._model_init_state_automatically_added = False

    @abstractmethod
    @DeveloperAPI
    def compute_actions(
            self,
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorType], TensorType] = None,
            prev_reward_batch: Union[List[TensorType], TensorType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for the current policy.

        Args:
            obs_batch (Union[List[TensorType], TensorType]): Batch of
                observations.
            state_batches (Optional[List[TensorType]]): List of RNN state
                input batches, if any.
            prev_action_batch (Union[List[TensorType], TensorType]): Batch of
                previous action values.
            prev_reward_batch (Union[List[TensorType], TensorType]): Batch of
                previous rewards.
            info_batch (Optional[Dict[str, list]]): Batch of info objects.
            episodes (Optional[List[MultiAgentEpisode]]): List of
                MultiAgentEpisode objects, one for each obs in obs_batch.
                This provides access to all of the internal episode state,
                which may be useful for model-based or multi-agent algorithms.
            explore (Optional[bool]): Whether to pick an exploitation or
                exploration action. Set to None (default) for using the
                value of `self.config["explore"]`.
            timestep (Optional[int]): The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple:
                actions (TensorType): Batch of output actions, with shape
                    like [BATCH_SIZE, ACTION_SHAPE].
                state_outs (List[TensorType]): List of RNN state output
                    batches, if any, with shape like [STATE_SIZE, BATCH_SIZE].
                info (List[dict]): Dictionary of extra feature batches, if
                    any, with shape like
                    {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    @DeveloperAPI
    def compute_single_action(
            self,
            obs: TensorType,
            state: Optional[List[TensorType]] = None,
            prev_action: Optional[TensorType] = None,
            prev_reward: Optional[TensorType] = None,
            info: dict = None,
            episode: Optional["MultiAgentEpisode"] = None,
            clip_actions: bool = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            unsquash_actions: bool = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Unbatched version of compute_actions.

        Args:
            obs (TensorType): Single observation.
            state (Optional[List[TensorType]]): List of RNN state inputs, if
                any.
            prev_action (Optional[TensorType]): Previous action value, if any.
            prev_reward (Optional[TensorType]): Previous reward, if any.
            info (dict): Info object, if any.
            episode (Optional[MultiAgentEpisode]): This provides access to all
                of the internal episode state, which may be useful for
                model-based or multi-agent algorithms.
            unsquash_actions (bool): Should actions be unsquashed according to
                the Policy's action space?
            clip_actions (bool): Should actions be clipped according to the
                Policy's action space?
            explore (Optional[bool]): Whether to pick an exploitation or
                exploration action
                (default: None -> use self.config["explore"]).
            timestep (Optional[int]): The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility.

        Returns:
            Tuple:
                - actions (TensorType): Single action.
                - state_outs (List[TensorType]): List of RNN state outputs,
                  if any.
                - info (dict): Dictionary of extra features, if any.
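
        Examples:
            Minimal sketch (assumes `policy` is an already-built Policy
            instance, e.g. as returned by `trainer.get_policy()`):

            >>> obs = policy.observation_space.sample()
            >>> action, rnn_states, info = policy.compute_single_action(
            ...     obs, explore=False)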
        """
        # If policy works in normalized space, we should unsquash the action.
        # Use value of config.normalize_actions, if None.
        unsquash_actions = \
            unsquash_actions if unsquash_actions is not None \
            else self.config["normalize_actions"]
        clip_actions = clip_actions if clip_actions is not None else \
            self.config["clip_actions"]

        prev_action_batch = None
        prev_reward_batch = None
        info_batch = None
        episodes = None
        state_batch = None
        if prev_action is not None:
            prev_action_batch = [prev_action]
        if prev_reward is not None:
            prev_reward_batch = [prev_reward]
        if info is not None:
            info_batch = [info]
        if episode is not None:
            episodes = [episode]
        if state is not None:
            state_batch = [
                s.unsqueeze(0)
                if torch and isinstance(s, torch.Tensor) else np.expand_dims(
                    s, 0) for s in state
            ]

        out = self.compute_actions(
            [obs],
            state_batch,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes,
            explore=explore,
            timestep=timestep)

        # Some policies don't return a tuple, but always just a single action.
        # E.g. ES and ARS.
        if not isinstance(out, tuple):
            single_action = out
            state_out = []
            info = {}
        # Normal case: Policy should return (action, state, info) tuple.
        else:
            batched_action, state_out, info = out
            single_action = unbatch(batched_action)
        assert len(single_action) == 1
        single_action = single_action[0]

        # If we work in normalized action space (normalize_actions=True),
        # we re-translate here into the env's action space.
        if unsquash_actions:
            single_action = unsquash_action(single_action,
                                            self.action_space_struct)
        # Clip, according to env's action space.
        elif clip_actions:
            single_action = clip_action(single_action,
                                        self.action_space_struct)

        # Return action, internal state(s), infos.
        return single_action, [s[0] for s in state_out], \
            {k: v[0] for k, v in info.items()}

    @DeveloperAPI
    def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from collected samples (across multiple agents).

        Uses the currently "forward-pass-registered" samples from the
        collector to construct the input_dict for the Model.

        Args:
            input_dict (Dict[str, TensorType]): An input dict mapping str
                keys to Tensors. `input_dict` already abides by the Policy's
                as well as the Model's view requirements and can be passed
                to the Model as-is.
            explore (bool): Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            timestep (Optional[int]): The current (sampling) time step.
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple:
                actions (TensorType): Batch of output actions, with shape
                    like [BATCH_SIZE, ACTION_SHAPE].
                state_outs (List[TensorType]): List of RNN state output
                    batches, if any, with shape like [STATE_SIZE, BATCH_SIZE].
                info (dict): Dictionary of extra feature batches, if any, with
                    shape like
                    {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [
            s for k, s in input_dict.items() if k[:9] == "state_in_"
        ]
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )

    @DeveloperAPI
    def compute_log_likelihoods(
            self,
            actions: Union[List[TensorType], TensorType],
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:
        """Computes the log-prob/likelihood for a given action and observation.

        Args:
            actions (Union[List[TensorType], TensorType]): Batch of actions,
                for which to retrieve the log-probs/likelihoods (given all
                other inputs: obs, states, ..).
            obs_batch (Union[List[TensorType], TensorType]): Batch of
                observations.
            state_batches (Optional[List[TensorType]]): List of RNN state
                input batches, if any.
            prev_action_batch (Optional[Union[List[TensorType], TensorType]]):
                Batch of previous action values.
            prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
                Batch of previous rewards.
            actions_normalized (bool): Is the given `actions` already
                normalized (between -1.0 and 1.0) or not? If not and
                `normalize_actions=True`, we need to normalize the given
                actions first, before calculating log likelihoods.

        Returns:
            TensorType: Batch of log probs/likelihoods, with shape:
                [BATCH_SIZE].
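
        Examples:
            Rough sketch (assumes a built, non-recurrent `policy`; the sampled
            values only serve as shape examples):

            >>> obs_batch = np.array([policy.observation_space.sample()])
            >>> actions = np.array([policy.action_space.sample()])
            >>> logps = policy.compute_log_likelihoods(actions, obs_batch)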
        """
        raise NotImplementedError

    @DeveloperAPI
    def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[Dict[AgentID, Tuple[
                "Policy", SampleBatch]]] = None,
            episode: Optional["MultiAgentEpisode"] = None) -> SampleBatch:
        """Implements algorithm-specific trajectory postprocessing.

        This will be called on each trajectory fragment computed during policy
        evaluation. Each fragment is guaranteed to be only from one episode.

        Args:
            sample_batch (SampleBatch): Batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches (dict): In a multi-agent env, this contains a
                mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agents.
            episode (Optional[MultiAgentEpisode]): An optional multi-agent
                episode object to provide access to all of the
                internal episode state, which may be useful for model-based or
                multi-agent algorithms.

        Returns:
            SampleBatch: Postprocessed sample batch.
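
        Examples:
            Sketch of an override in a custom Policy subclass (illustrative;
            adds a hypothetical constant "bonus" column to each batch):

            >>> class MyPolicy(Policy):
            ...     def postprocess_trajectory(self, sample_batch,
            ...                                other_agent_batches=None,
            ...                                episode=None):
            ...         sample_batch["bonus"] = np.ones_like(
            ...             sample_batch[SampleBatch.REWARDS])
            ...         return sample_batch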
        """
        return sample_batch

    @DeveloperAPI
    def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
        """Fused compute gradients and apply gradients call.

        Either this or the combination of compute/apply grads must be
        implemented by subclasses.

        Args:
            samples (SampleBatch): The SampleBatch object to learn from.

        Returns:
            Dict[str, TensorType]: Dictionary of extra metadata from
                compute_gradients().

        Examples:
            >>> sample_batch = ev.sample()
            >>> ev.learn_on_batch(sample_batch)
        """
        grads, grad_info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return grad_info

    @DeveloperAPI
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> \
            Tuple[ModelGradients, Dict[str, TensorType]]:
        """Computes gradients against a batch of experiences.

        Either this or learn_on_batch() must be implemented by subclasses.

        Args:
            postprocessed_batch (SampleBatch): The SampleBatch object to use
                for calculating gradients.

        Returns:
            Tuple[ModelGradients, Dict[str, TensorType]]:
                - List of gradient output values.
                - Extra policy-specific info values.
        """
        raise NotImplementedError

    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Applies previously computed gradients.

        Either this or learn_on_batch() must be implemented by subclasses.

        Args:
            gradients (ModelGradients): The already calculated gradients to
                apply to this Policy.
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_weights(self) -> ModelWeights:
        """Returns model weights.

        Note: The return value of this method will reside under the "weights"
        key in the return value of Policy.get_state().

        Returns:
            ModelWeights: Serializable copy or view of model weights.
        """
        raise NotImplementedError

    @DeveloperAPI
    def set_weights(self, weights: ModelWeights) -> None:
        """Sets this Policy's model's weights.

        Args:
            weights (ModelWeights): Serializable copy or view of model weights.
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_exploration_state(self) -> Dict[str, TensorType]:
        """Returns the current exploration information of this policy.

        This information depends on the policy's Exploration object.

        Returns:
            Dict[str, TensorType]: Serializable information on the
                `self.exploration` object.
        """
        return self.exploration.get_state()

    # TODO: (sven) Deprecate this method.
    def get_exploration_info(self) -> Dict[str, TensorType]:
        deprecation_warning("get_exploration_info", "get_exploration_state")
        return self.get_exploration_state()

    @DeveloperAPI
    def is_recurrent(self) -> bool:
        """Whether this Policy holds a recurrent Model.

        Returns:
            bool: True if this Policy has-a RNN-based Model.
        """
        return False

    @DeveloperAPI
    def num_state_tensors(self) -> int:
        """The number of internal states needed by the RNN-Model of the Policy.

        Returns:
            int: The number of RNN internal states kept by this Policy's Model.
        """
        return 0

    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        """Returns initial RNN state for the current policy.

        Returns:
            List[TensorType]: Initial RNN state for the current policy.
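
        Examples:
            Sketch of an override for a subclass whose model keeps a single
            hidden-state vector (the size of 256 is purely illustrative):

            >>> def get_initial_state(self):
            ...     return [np.zeros(256, dtype=np.float32)]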
        """
        return []

    @DeveloperAPI
    def load_batch_into_buffer(self, batch: SampleBatch,
                               buffer_index: int = 0) -> int:
        """Bulk-loads the given SampleBatch into the devices' memories.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data should be discarded.

        Args:
            batch (SampleBatch): The SampleBatch to load.
            buffer_index (int): The index of the buffer (a MultiGPUTowerStack)
                to use on the devices.

        Returns:
            int: The number of tuples loaded per device.
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        """Returns the number of currently loaded samples in the given buffer.

        Args:
            buffer_index (int): The index of the buffer (a MultiGPUTowerStack)
                to check on the devices.

        Returns:
            int: The number of samples currently loaded into that buffer.
        """
        raise NotImplementedError

    @DeveloperAPI
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs a single step of SGD on already loaded data in a buffer.

        Runs an SGD step over a slice of the pre-loaded batch, offset by
        the `offset` argument (useful for performing n minibatch SGD
        updates repeatedly on the same, already pre-loaded data).

        Updates shared model weights based on the averaged per-device
        gradients.

        Args:
            offset (int): Offset into the preloaded data. Used for pre-loading
                a train-batch once to a device, then iterating over
                (subsampling through) this batch n times doing minibatch SGD.
            buffer_index (int): The index of the buffer (a MultiGPUTowerStack)
                to take the already pre-loaded data from.

        Returns:
            The outputs of extra_ops evaluated over the batch.
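
        Examples:
            Sketch of the intended load-once/train-many pattern (buffer index
            0; the minibatch size of 128 is purely illustrative):

            >>> n = policy.load_batch_into_buffer(train_batch, buffer_index=0)
            >>> for offset in range(0, n, 128):
            ...     results = policy.learn_on_loaded_batch(
            ...         offset=offset, buffer_index=0)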
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        """Returns all local state.

        Note: Not to be confused with an RNN model's internal state.

        Returns:
            Union[Dict[str, TensorType], List[TensorType]]: Serialized local
                state.
        """
        state = {
            # All the policy's weights.
            "weights": self.get_weights(),
            # The current global timestep.
            "global_timestep": self.global_timestep,
        }
        return state

    @DeveloperAPI
    def set_state(self, state: object) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (object): The new state to set this policy to. Can be
                obtained by calling `Policy.get_state()`.
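
        Examples:
            Sketch of a state round-trip between two policies built with the
            same config (hypothetical `policy_a`/`policy_b` instances):

            >>> state = policy_a.get_state()
            >>> policy_b.set_state(state)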
        """
        self.set_weights(state["weights"])
        self.global_timestep = state["global_timestep"]

    @DeveloperAPI
    def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
        """Called on an update to global vars.

        Args:
            global_vars (Dict[str, TensorType]): Global variables by str key,
                broadcast from the driver.
        """
        # Store the current global time step (sum over all policies' sample
        # steps).
        self.global_timestep = global_vars["timestep"]

    @DeveloperAPI
    def export_model(self, export_dir: str,
                     onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Note: The file format will depend on the deep learning framework used.
        See the child classes of Policy and their `export_model`
        implementations for more details.

        Args:
            export_dir (str): Local writable directory.
            onnx (Optional[int]): If given, will export model in ONNX format.
                The value of this parameter sets the ONNX OpSet version to
                use.
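
        Examples:
            Illustrative usage only (paths are placeholders):

            >>> policy.export_model("/tmp/my_policy_export")
            >>> # Assuming the framework in use supports ONNX export:
            >>> policy.export_model("/tmp/my_policy_onnx", onnx=11)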
        """
        raise NotImplementedError

    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports Policy from local file.

        Args:
            import_file (str): Local readable file.
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_session(self) -> Optional["tf1.Session"]:
        """Returns tf.Session object to use for computing actions or None.

        Returns:
            Optional[tf1.Session]: The tf Session to use for computing actions
                and losses with this policy.
        """
        return None

    def _create_exploration(self) -> Exploration:
        """Creates the Policy's Exploration object.

        This method only exists b/c some Trainers do not use TfPolicy nor
        TorchPolicy, but inherit directly from Policy. Others inherit from
        TfPolicy w/o using DynamicTfPolicy.
        TODO(sven): unify these cases.

        Returns:
            Exploration: The Exploration object to be used by this Policy.
        """
        if getattr(self, "exploration", None) is not None:
            return self.exploration

        exploration = from_config(
            Exploration,
            self.config.get("exploration_config",
                            {"type": "StochasticSampling"}),
            action_space=self.action_space,
            policy_config=self.config,
            model=getattr(self, "model", None),
            num_workers=self.config.get("num_workers", 0),
            worker_index=self.config.get("worker_index", 0),
            framework=getattr(self, "framework",
                              self.config.get("framework", "tf")))
        return exploration

    def _get_default_view_requirements(self):
        """Returns a default ViewRequirements dict.

        Note: This is the base/maximum requirement dict, from which later
        some requirements will be subtracted again automatically to streamline
        data collection, batch creation, and data transfer.

        Returns:
            ViewReqDict: The default view requirements dict.
        """
        # Default view requirements (equal to those that we would use before
        # the trajectory view API was introduced).
        return {
            SampleBatch.OBS: ViewRequirement(space=self.observation_space),
            SampleBatch.NEXT_OBS: ViewRequirement(
                data_col=SampleBatch.OBS,
                shift=1,
                space=self.observation_space),
            SampleBatch.ACTIONS: ViewRequirement(
                space=self.action_space, used_for_compute_actions=False),
            # For backward compatibility with custom Models that don't specify
            # these explicitly (will be removed by Policy if not used).
            SampleBatch.PREV_ACTIONS: ViewRequirement(
                data_col=SampleBatch.ACTIONS,
                shift=-1,
                space=self.action_space),
            SampleBatch.REWARDS: ViewRequirement(),
            # For backward compatibility with custom Models that don't specify
            # these explicitly (will be removed by Policy if not used).
            SampleBatch.PREV_REWARDS: ViewRequirement(
                data_col=SampleBatch.REWARDS, shift=-1),
            SampleBatch.DONES: ViewRequirement(),
            SampleBatch.INFOS: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            SampleBatch.UNROLL_ID: ViewRequirement(),
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            "t": ViewRequirement(),
        }

    def _initialize_loss_from_dummy_batch(
            self,
            auto_remove_unneeded_view_reqs: bool = True,
            stats_fn=None,
    ) -> None:
        """Performs test calls through policy's model and loss.

        NOTE: This base method should work for define-by-run Policies such as
        torch and tf-eager policies.

        If required, will thereby detect automatically, which data views are
        required by a) the forward pass, b) the postprocessing, and c) the
        loss functions, and remove those from self.view_requirements that are
        not necessary for these computations (to save data storage and
        transfer).

        Args:
            auto_remove_unneeded_view_reqs (bool): Whether to automatically
                remove those ViewRequirements records from
                self.view_requirements that are not needed.
            stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str,
                TensorType]]]): An optional stats function to be called after
                the loss.
        """
        sample_batch_size = max(self.batch_divisibility_req * 4, 32)
        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
            sample_batch_size)
        self._lazy_tensor_dict(self._dummy_batch)
        actions, state_outs, extra_outs = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False)
        # Add all extra action outputs to view requirements (these may be
        # filtered out later again, if not needed for postprocessing or loss).
        for key, value in extra_outs.items():
            self._dummy_batch[key] = value
            if key not in self.view_requirements:
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype),
                        used_for_compute_actions=False)
        for key in self._dummy_batch.accessed_keys:
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement()
            self.view_requirements[key].used_for_compute_actions = True
        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
            sample_batch_size)
        self._dummy_batch.set_get_interceptor(None)
        self.exploration.postprocess_trajectory(self, self._dummy_batch)
        postprocessed_batch = self.postprocess_trajectory(self._dummy_batch)
        seq_lens = None
        if state_outs:
            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
            i = 0
            while "state_in_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_in_{}".format(i)] = \
                    postprocessed_batch["state_in_{}".format(i)][:B]
                if "state_out_{}".format(i) in postprocessed_batch:
                    postprocessed_batch["state_out_{}".format(i)] = \
                        postprocessed_batch["state_out_{}".format(i)][:B]
                i += 1
            seq_len = sample_batch_size // B
            seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
            postprocessed_batch["seq_lens"] = seq_lens
        # Switch on lazy to-tensor conversion on `postprocessed_batch`.
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # Calling loss, so set `is_training` to True.
        train_batch.is_training = True
        if seq_lens is not None:
            train_batch["seq_lens"] = seq_lens
        train_batch.count = self._dummy_batch.count
        # Call the loss function, if it exists.
        if self._loss is not None:
            self._loss(self, self.model, self.dist_class, train_batch)
        # Call the stats fn, if given.
        if stats_fn is not None:
            stats_fn(self, train_batch)

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | \
                self._dummy_batch.accessed_keys | \
                self._dummy_batch.added_keys
            for key in all_accessed_keys:
                if key not in self.view_requirements:
                    self.view_requirements[key] = ViewRequirement()
            if self._loss:
                # Tag those only needed for post-processing (with some
                # exceptions).
                for key in self._dummy_batch.accessed_keys:
                    if key not in train_batch.accessed_keys and \
                            key in self.view_requirements and \
                            key not in self.model.view_requirements and \
                            key not in [
                                SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                                SampleBatch.UNROLL_ID, SampleBatch.DONES,
                                SampleBatch.REWARDS, SampleBatch.INFOS]:
                        self.view_requirements[key].used_for_training = False
                # Remove those not needed at all (leave those that are needed
                # by Sampler to properly execute sample collection).
                # Also always leave DONES, REWARDS, INFOS, no matter what.
                for key in list(self.view_requirements.keys()):
                    if key not in all_accessed_keys and key not in [
                        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID, SampleBatch.DONES,
                        SampleBatch.REWARDS, SampleBatch.INFOS] and \
                            key not in self.model.view_requirements:
                        # If user deleted this key manually in postprocessing
                        # fn, warn about it and do not remove from
                        # view-requirements.
                        if key in self._dummy_batch.deleted_keys:
                            logger.warning(
                                "SampleBatch key '{}' was deleted manually in "
                                "postprocessing function! RLlib will "
                                "automatically remove non-used items from the "
                                "data stream. Remove the `del` from your "
                                "postprocessing function.".format(key))
                        else:
                            del self.view_requirements[key]

    def _get_dummy_batch_from_view_requirements(
            self, batch_size: int = 1) -> SampleBatch:
        """Creates a numpy dummy batch based on the Policy's view requirements.

        Args:
            batch_size (int): The size of the batch to create.

        Returns:
            Dict[str, TensorType]: The dummy batch containing all zero values.
        """
        ret = {}
        for view_col, view_req in self.view_requirements.items():
            if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
                _, shape = ModelCatalog.get_action_shape(
                    view_req.space, framework=self.config["framework"])
                ret[view_col] = \
                    np.zeros((batch_size, ) + shape[1:], np.float32)
            else:
                # Range of indices on time-axis, e.g. "-50:-1".
                if view_req.shift_from is not None:
                    ret[view_col] = np.zeros_like([[
                        view_req.space.sample()
                        for _ in range(view_req.shift_to -
                                       view_req.shift_from + 1)
                    ] for _ in range(batch_size)])
                # Set of (probably non-consecutive) indices.
                elif isinstance(view_req.shift, (list, tuple)):
                    ret[view_col] = np.zeros_like([[
                        view_req.space.sample()
                        for t in range(len(view_req.shift))
                    ] for _ in range(batch_size)])
                # Single shift int value.
                else:
                    if isinstance(view_req.space, gym.spaces.Space):
                        ret[view_col] = np.zeros_like([
                            view_req.space.sample()
                            for _ in range(batch_size)
                        ])
                    else:
                        ret[view_col] = [
                            view_req.space for _ in range(batch_size)
                        ]

        # Due to different view requirements for the different columns,
        # columns in the resulting batch may not all have the same batch size.
        return SampleBatch(ret)

    def _update_model_view_requirements_from_init_state(self):
        """Uses Model's (or this Policy's) init state to add needed ViewReqs.

        Can be called from within a Policy to make sure RNNs automatically
        update their internal state-related view requirements.
        Changes the `self.view_requirements` dict.
        """
        self._model_init_state_automatically_added = True
        model = getattr(self, "model", None)

        obj = model or self
        if model and not hasattr(model, "view_requirements"):
            model.view_requirements = {
                SampleBatch.OBS: ViewRequirement(space=self.observation_space)
            }
        view_reqs = obj.view_requirements
        # Add state-ins to this model's view.
        init_state = []
        if hasattr(obj, "get_initial_state") and callable(
                obj.get_initial_state):
            init_state = obj.get_initial_state()
        else:
            # Add this functionality automatically for new native model API.
            if tf and isinstance(model, tf.keras.Model) and \
                    "state_in_0" not in view_reqs:
                obj.get_initial_state = lambda: [
                    np.zeros_like(view_req.space.sample())
                    for k, view_req in model.view_requirements.items()
                    if k.startswith("state_in_")
                ]
            else:
                obj.get_initial_state = lambda: []
            if "state_in_0" in view_reqs:
                self.is_recurrent = lambda: True

        # Make sure auto-generated init-state view requirements get added
        # to both Policy and Model, no matter what.
        view_reqs = [view_reqs] + ([self.view_requirements] if hasattr(
            self, "view_requirements") else [])

        for i, state in enumerate(init_state):
            # Allow `state` to be either a Space (use zeros as initial values)
            # or any value (e.g. a dict or a non-zero tensor).
            fw = np if isinstance(state, np.ndarray) else torch if \
                torch and torch.is_tensor(state) else None
            if fw:
                space = Box(-1.0, 1.0, shape=state.shape) if \
                    fw.all(state == 0.0) else state
            else:
                space = state
            for vr in view_reqs:
                vr["state_in_{}".format(i)] = ViewRequirement(
                    "state_out_{}".format(i),
                    shift=-1,
                    used_for_compute_actions=True,
                    batch_repeat_value=self.config.get("model", {}).get(
                        "max_seq_len", 1),
                    space=space)
                vr["state_out_{}".format(i)] = ViewRequirement(
                    space=space, used_for_training=True)

    # TODO: (sven) Deprecate this in favor of `save()`.
    def export_checkpoint(self, export_dir: str) -> None:
        """Export Policy checkpoint to local directory.

        Args:
            export_dir (str): Local writable directory.
        """
        deprecation_warning("export_checkpoint", "save")
        raise NotImplementedError