2020-08-06 10:54:20 +02:00
|
|
|
import gym
|
2020-12-21 02:22:32 +01:00
|
|
|
import numpy as np
|
2020-08-15 15:09:00 +02:00
|
|
|
from typing import List, Optional, Union
|
2020-08-06 10:54:20 +02:00
|
|
|
|
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
|
|
|
|
|
|
torch, _ = try_import_torch()
|
|
|
|
|
|
|
|
|
|
|
|
class ViewRequirement:
    """Single view requirement (for one column in a SampleBatch/input_dict).

    Note: This is an experimental class used only if
    `_use_trajectory_view_api` in the config is set to True.

    Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
    their `[train|inference]_view_requirements()` methods, where the str key
    represents the column name (C) under which the view is available in the
    input_dict/SampleBatch and ViewRequirement specifies the actual underlying
    column names (in the original data buffer), timestep shifts, and other
    options to build the view.

    Examples:
        >>> # The default ViewRequirement for a Model is:
        >>> req = [ModelV2].view_requirements
        >>> print(req)
        {"obs": ViewRequirement(shift=0)}
    """

    def __init__(self,
                 data_col: Optional[str] = None,
                 space: Optional[gym.Space] = None,
                 shift: Union[int, str, List[int]] = 0,
                 index: Optional[int] = None,
                 batch_repeat_value: int = 1,
                 used_for_training: bool = True):
        """Initializes a ViewRequirement object.

        Args:
            data_col (Optional[str]): The data column name from the SampleBatch
                (str key). If None, use the dict key under which this
                ViewRequirement resides.
            space (Optional[gym.Space]): The gym Space used in case we need to
                pad data in inaccessible areas of the trajectory (t<0 or t>H).
                Default: Simple box space, e.g. rewards.
            shift (Union[int, str, List[int]]): Single shift value or
                list of relative positions to use (relative to the underlying
                `data_col`).
                Example: For a view column "prev_actions", you can set
                `data_col="actions"` and `shift=-1`.
                Example: For a view column "obs" in an Atari framestacking
                fashion, you can set `data_col="obs"` and
                `shift=[-3, -2, -1, 0]`.
                Example: For the obs input to an attention net, you can specify
                a range via a str: `shift="-100:0"`, which will pass in
                the past 100 observations plus the current one.
            index (Optional[int]): An optional absolute position arg,
                used e.g. for the location of a requested inference dict within
                the trajectory. Negative values refer to counting from the end
                of a trajectory.
            batch_repeat_value (int): Stored as-is on the instance.
                NOTE(review): presumably the timestep interval at which the
                view is repeated within a batch - confirm against the
                code that consumes this attribute.
            used_for_training (bool): Whether the data will be used for
                training. If False, the column will not be copied into the
                final train batch.

        Raises:
            ValueError: If `shift` is a str that is not of the form
                "<int>:<int>" (e.g. "-100:0").
        """
        self.data_col = data_col
        # Fall back to an unbounded scalar Box if no space was provided.
        self.space = space if space is not None else gym.spaces.Box(
            float("-inf"), float("inf"), shape=())

        # Lists/tuples of shifts are stored as a numpy array; ints and strs
        # are kept unchanged.
        self.shift = shift
        if isinstance(self.shift, (list, tuple)):
            self.shift = np.array(self.shift)

        # Special case: Providing a (probably larger) range of indices, e.g.
        # "-100:0" (past 100 timesteps plus current one).
        # `shift_from`/`shift_to` hold the parsed int bounds of such a range
        # and stay None for int/list shifts.
        self.shift_from = self.shift_to = None
        if isinstance(self.shift, str):
            try:
                # Parse into locals first so a bad string cannot leave the
                # instance with only one of the two bounds set.
                from_str, to_str = self.shift.split(":")
                shift_from, shift_to = int(from_str), int(to_str)
            except ValueError:
                # Same exception type as the bare unpack/int() failure, but
                # with a message that names the offending value.
                raise ValueError(
                    "`shift` given as a str must have the format 'from:to' "
                    "with int bounds (e.g. '-100:0'), got {!r}!".format(
                        self.shift)) from None
            self.shift_from = shift_from
            self.shift_to = shift_to

        self.index = index
        self.batch_repeat_value = batch_repeat_value

        self.used_for_training = used_for_training