ray/rllib/policy/view_requirement.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

133 lines
5.1 KiB
Python
Raw Normal View History

import gym
from typing import Dict, List, Optional, Union
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.serialization import (
gym_space_to_dict,
gym_space_from_dict,
)
# Import torch through RLlib's `try_import_torch()` helper. It returns a
# 2-tuple; only the first element is kept (bound as `torch`) and the second is
# discarded. NOTE(review): presumably `torch` is None when PyTorch is not
# installed — confirm in ray.rllib.utils.framework.
torch, _ = try_import_torch()
@PublicAPI
class ViewRequirement:
    """Single view requirement (for one column in a SampleBatch/input_dict).

    Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
    their `[train|inference]_view_requirements()` methods, where the str key
    represents the column name (C) under which the view is available in the
    input_dict/SampleBatch and ViewRequirement specifies the actual underlying
    column names (in the original data buffer), timestep shifts, and other
    options to build the view.

    Examples:
        >>> from ray.rllib.models.modelv2 import ModelV2
        >>> # The default ViewRequirement for a Model is:
        >>> req = ModelV2(...).view_requirements # doctest: +SKIP
        >>> print(req) # doctest: +SKIP
        {"obs": ViewRequirement(shift=0)}
    """

    def __init__(
        self,
        data_col: Optional[str] = None,
        space: Optional[gym.Space] = None,
        shift: Union[int, str, List[int]] = 0,
        index: Optional[int] = None,
        batch_repeat_value: int = 1,
        used_for_compute_actions: bool = True,
        used_for_training: bool = True,
    ):
        """Initializes a ViewRequirement object.

        Args:
            data_col: The data column name from the SampleBatch (str key).
                If None, use the dict key under which this ViewRequirement
                resides.
            space: The gym Space used in case we need to pad data in
                inaccessible areas of the trajectory (t<0 or t>H).
                If None, an unbounded scalar Box space is used (suitable
                e.g. for rewards).
            shift: Single shift value or list of relative positions to use
                (relative to the underlying `data_col`).
                Example: For a view column "prev_actions", you can set
                `data_col="actions"` and `shift=-1`.
                Example: For a view column "obs" in an Atari framestacking
                fashion, you can set `data_col="obs"` and
                `shift=[-3, -2, -1, 0]`.
                Example: For the obs input to an attention net, you can
                specify a range via a str: `shift="-100:0"`, which will pass
                in the past 100 observations plus the current one.
            index: An optional absolute position arg, used e.g. for the
                location of a requested inference dict within the trajectory.
                Negative values refer to counting from the end of a
                trajectory.
            batch_repeat_value: Stored as-is on the instance and included in
                `to_dict()` / `__str__()`.
            used_for_compute_actions: Whether the data will be used for
                creating input_dicts for `Policy.compute_actions()` calls (or
                `Policy.compute_actions_from_input_dict()`).
            used_for_training: Whether the data will be used for training.
                If False, the column will not be copied into the final train
                batch.

        Raises:
            ValueError: If `shift` is a str that is not of the form
                "FROM:TO" with two integer bounds.
        """
        self.data_col = data_col
        # Fall back to an unbounded scalar Box when no space is given.
        self.space = (
            space
            if space is not None
            else gym.spaces.Box(float("-inf"), float("inf"), shape=())
        )

        self.shift = shift
        # Special case: Providing a (probably larger) range of indices, e.g.
        # "-100:0" (past 100 timesteps plus current one). `shift_from` and
        # `shift_to` stay None unless `shift` is such a range string.
        self.shift_from = self.shift_to = None
        if isinstance(self.shift, str):
            bounds = self.shift.split(":")
            if len(bounds) != 2:
                raise ValueError(
                    f"Invalid range string for `shift`: {self.shift!r}. "
                    "Expected the form 'FROM:TO', e.g. '-100:0'."
                )
            try:
                self.shift_from = int(bounds[0])
                self.shift_to = int(bounds[1])
            except ValueError as e:
                raise ValueError(
                    f"Invalid range string for `shift`: {self.shift!r}. "
                    "Both bounds of 'FROM:TO' must be integers."
                ) from e

        self.index = index

        self.batch_repeat_value = batch_repeat_value
        self.used_for_compute_actions = used_for_compute_actions
        self.used_for_training = used_for_training

    def __str__(self) -> str:
        """Return a '|'-separated summary of all fields, for easy inspection."""
        fields = [
            self.data_col,
            self.space,
            self.shift,
            self.shift_from,
            self.shift_to,
            self.index,
            self.batch_repeat_value,
            self.used_for_training,
            self.used_for_compute_actions,
        ]
        return "|".join(str(v) for v in fields)

    def to_dict(self) -> Dict:
        """Return a dict for this ViewRequirement that can be JSON serialized.

        The derived `shift_from`/`shift_to` attributes are not stored; they
        are recomputed from `shift` by the constructor in `from_dict()`.
        """
        return {
            "data_col": self.data_col,
            "space": gym_space_to_dict(self.space),
            "shift": self.shift,
            "index": self.index,
            "batch_repeat_value": self.batch_repeat_value,
            "used_for_training": self.used_for_training,
            "used_for_compute_actions": self.used_for_compute_actions,
        }

    @classmethod
    def from_dict(cls, d: Dict) -> "ViewRequirement":
        """Construct a ViewRequirement instance from a JSON-deserialized dict.

        Args:
            d: A dict as produced by `to_dict()` (e.g. after a JSON round
                trip). It is not modified.

        Returns:
            A new ViewRequirement built from the entries of `d`.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        # Overwriting d["space"] in place would make a second call on the
        # same dict pass an already-deserialized Space to
        # `gym_space_from_dict`.
        kwargs = dict(d)
        kwargs["space"] = gym_space_from_dict(kwargs["space"])
        return cls(**kwargs)