[RLlib] No Preprocessors (part 2). (#18468)
parent a2a077b874
commit 61a1274619

25 changed files with 657 additions and 308 deletions

rllib/BUILD (20 lines changed)
@@ -1460,7 +1460,7 @@ py_test(
 py_test(
     name = "test_preprocessors",
     tags = ["team:ml", "models"],
-    size = "small",
+    size = "medium",
     srcs = ["models/tests/test_preprocessors.py"]
 )

@@ -2659,6 +2659,24 @@ py_test(
     srcs = ["examples/pettingzoo_env.py"],
 )

+py_test(
+    name = "examples/preprocessing_disabled_tf",
+    main = "examples/preprocessing_disabled.py",
+    tags = ["team:ml", "examples", "examples_P"],
+    size = "medium",
+    srcs = ["examples/preprocessing_disabled.py"],
+    args = ["--stop-iters=2"]
+)
+
+py_test(
+    name = "examples/preprocessing_disabled_torch",
+    main = "examples/preprocessing_disabled.py",
+    tags = ["team:ml", "examples", "examples_P"],
+    size = "medium",
+    srcs = ["examples/preprocessing_disabled.py"],
+    args = ["--framework=torch", "--stop-iters=2"]
+)
+
 py_test(
     name = "examples/remote_envs_with_inference_done_on_main_node_tf",
     main = "examples/remote_envs_with_inference_done_on_main_node.py",
@@ -30,12 +30,13 @@ from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \
 from ray.rllib.utils.debug import update_global_seed_if_necessary
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
 from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
-from ray.rllib.utils.framework import try_import_tf, TensorStructType
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.multi_agent import check_multi_agent
 from ray.rllib.utils.spaces import space_utils
 from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
-    PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
+    PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \
+    TrainerConfigDict
 from ray.tune.logger import Logger, UnifiedLogger
 from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
 from ray.tune.resources import Resources
@@ -113,11 +114,6 @@ COMMON_CONFIG: TrainerConfigDict = {
     "model": MODEL_DEFAULTS,
     # Arguments to pass to the policy optimizer. These vary by optimizer.
     "optimizer": {},
-    # Experimental flag, indicating that TFPolicy will handle more than one
-    # loss/optimizer. Set this to True, if you would like to return more than
-    # one loss term from your `loss_fn` and an equal number of optimizers
-    # from your `optimizer_fn`.
-    "_tf_policy_handles_more_than_one_loss": False,

     # === Environment Settings ===
     # Number of steps after which the episode is forced to terminate. Defaults
@@ -483,6 +479,20 @@ COMMON_CONFIG: TrainerConfigDict = {
     # Default value None allows overwriting with nested dicts
     "logger_config": None,

+    # === API deprecations/simplifications/changes ===
+    # Experimental flag.
+    # If True, TFPolicy will handle more than one loss/optimizer.
+    # Set this to True, if you would like to return more than
+    # one loss term from your `loss_fn` and an equal number of optimizers
+    # from your `optimizer_fn`.
+    # In the future, the default for this will be True.
+    "_tf_policy_handles_more_than_one_loss": False,
+    # Experimental flag.
+    # If True, no (observation) preprocessor will be created and
+    # observations will arrive in model as they are returned by the env.
+    # In the future, the default for this will be True.
+    "_disable_preprocessor_api": False,
+
     # === Deprecated keys ===
     # Uses the sync samples optimizer instead of the multi-gpu one. This is
     # usually slower, but you might want to try it if you run into issues with
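The `_disable_preprocessor_api` flag added above is a plain top-level trainer config key (it also gets mirrored into the model config further down). A minimal sketch of switching it on, modeled on the test and example added in this commit; the env and framework choices here are only illustrative:

    import ray
    import ray.rllib.agents.ppo as ppo

    ray.init()

    config = ppo.DEFAULT_CONFIG.copy()
    config["env"] = "CartPole-v0"  # any env; Dict/Tuple obs spaces are where the flag matters
    config["framework"] = "tf"
    # Skip preprocessor creation entirely; observations reach the model
    # exactly as the environment returns them.
    config["_disable_preprocessor_api"] = True

    trainer = ppo.PPOTrainer(config=config)
    print(trainer.train())
    trainer.stop()
    ray.shutdown()
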
@@ -1128,8 +1138,8 @@ class Trainer(Trainable):
             tuple: The full output of policy.compute_actions() if
                 full_fetch=True or we have an RNN-based Policy.
         """
-        # Preprocess obs and states
-        stateDefined = state is not None
+        # Preprocess obs and states.
+        state_defined = state is not None
         policy = self.get_policy(policy_id)
         filtered_obs, filtered_state = [], []
         for agent_id, ob in observations.items():
@@ -1174,7 +1184,7 @@ class Trainer(Trainable):
                 unbatched_states[agent_id] = [s[idx] for s in states]

         # Return only actions or full tuple
-        if stateDefined or full_fetch:
+        if state_defined or full_fetch:
             return actions, unbatched_states, infos
         else:
             return actions
@@ -1529,8 +1539,8 @@ class Trainer(Trainable):
         # Check model config.
         # If no preprocessing, propagate into model's config as well
         # (so model will know, whether inputs are preprocessed or not).
-        if config["preprocessor_pref"] is None:
-            model_config["_no_preprocessor"] = True
+        if config["_disable_preprocessor_api"] is True:
+            model_config["_disable_preprocessor_api"] = True

         # Prev_a/r settings.
         prev_a_r = model_config.get("lstm_use_prev_action_reward",

@@ -3,6 +3,7 @@ from gym.spaces import Space
 import logging
+import math
 import numpy as np
 import tree  # pip install dm_tree
 from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
@@ -14,6 +15,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
 from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
     TensorType, ViewRequirementsDict
 from ray.util.debug import log_once
@@ -47,7 +49,8 @@ class _AgentCollector:

     _next_unroll_id = 0  # disambiguates unrolls within a single episode

-    def __init__(self, view_reqs):
+    def __init__(self, view_reqs, policy):
+        self.policy = policy
         # Determine the size of the buffer we need for data before the actual
         # episode starts. This is used for 0-buffering of e.g. prev-actions,
         # or internal state inputs.
@@ -57,10 +60,28 @@ class _AgentCollector:
             (1
              if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0)
             for k, vr in view_reqs.items())
-        # The actual data buffers (lists holding each timestep's data).
-        self.buffers: Dict[str, List] = {}
+
+        # The actual data buffers. Keys are column names, values are lists
+        # that contain the sub-components (e.g. for complex obs spaces) with
+        # each sub-component holding a list of per-timestep tensors.
+        # E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,)))
+        # buffers["obs"] = [
+        #    [0, 1],  # <- 1st sub-component of observation
+        #    [np.array([.2, .3]), np.array([.0, -.2])]  # <- 2nd sub-component
+        # ]
+        # NOTE: infos and state_out_... are not flattened due to them often
+        # using custom dict values whose structure may vary from timestep to
+        # timestep.
+        self.buffers: Dict[str, List[List[TensorType]]] = {}
+        # Maps column names to an example data item, which may be deeply
+        # nested. These are used such that we'll know how to unflatten
+        # the flattened data inside self.buffers when building the
+        # SampleBatch.
+        self.buffer_structs: Dict[str, Any] = {}
         # The episode ID for the agent for which we collect data.
         self.episode_id = None
+        # The unroll ID, unique across all rollouts (within a RolloutWorker).
+        self.unroll_id = None
         # The simple timestep count for this agent. Gets increased by one
         # each time a (non-initial!) observation is added.
         self.agent_steps = 0
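The comment block above describes the new flattened buffer layout. A small sketch of how `tree.flatten` produces those per-component sub-lists for the toy space from the comment (plain numpy, outside of RLlib):

    import numpy as np
    import tree  # pip install dm_tree
    from gym.spaces import Box, Dict, Discrete

    # Same toy observation space as in the comment above.
    obs_space = Dict({"a": Discrete(2), "b": Box(-1.0, 1.0, (2, ), dtype=np.float32)})

    buffers = {"obs": None}
    for _ in range(2):
        obs = obs_space.sample()        # e.g. OrderedDict(a=1, b=array([0.2, 0.3]))
        flat = tree.flatten(obs)        # -> [1, array([0.2, 0.3])]
        if buffers["obs"] is None:
            buffers["obs"] = [[] for _ in flat]   # one sub-list per component
        for i, sub_obs in enumerate(flat):
            buffers["obs"][i].append(sub_obs)

    # buffers["obs"][0] holds the Discrete values, buffers["obs"][1] the Box arrays.
    print(buffers["obs"])
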
@ -80,6 +101,13 @@ class _AgentCollector:
|
|||
init_obs (TensorType): The initial observation tensor (after
|
||||
`env.reset()`).
|
||||
"""
|
||||
# Store episode ID + unroll ID, which will be constant throughout this
|
||||
# AgentCollector's lifecycle.
|
||||
self.episode_id = episode_id
|
||||
if self.unroll_id is None:
|
||||
self.unroll_id = _AgentCollector._next_unroll_id
|
||||
_AgentCollector._next_unroll_id += 1
|
||||
|
||||
if SampleBatch.OBS not in self.buffers:
|
||||
self._build_buffers(
|
||||
single_row={
|
||||
|
@ -87,12 +115,19 @@ class _AgentCollector:
|
|||
SampleBatch.AGENT_INDEX: agent_index,
|
||||
SampleBatch.ENV_ID: env_id,
|
||||
SampleBatch.T: t,
|
||||
SampleBatch.EPS_ID: self.episode_id,
|
||||
SampleBatch.UNROLL_ID: self.unroll_id,
|
||||
})
|
||||
self.buffers[SampleBatch.OBS].append(init_obs)
|
||||
self.episode_id = episode_id
|
||||
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
|
||||
self.buffers[SampleBatch.ENV_ID].append(env_id)
|
||||
self.buffers[SampleBatch.T].append(t)
|
||||
|
||||
# Append data to existing buffers.
|
||||
flattened = tree.flatten(init_obs)
|
||||
for i, sub_obs in enumerate(flattened):
|
||||
self.buffers[SampleBatch.OBS][i].append(sub_obs)
|
||||
self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
|
||||
self.buffers[SampleBatch.ENV_ID][0].append(env_id)
|
||||
self.buffers[SampleBatch.T][0].append(t)
|
||||
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
|
||||
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
|
||||
|
||||
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
|
||||
None:
|
||||
|
@ -103,20 +138,40 @@ class _AgentCollector:
|
|||
row) to be added to buffer. Must contain keys:
|
||||
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
|
||||
"""
|
||||
if self.unroll_id is None:
|
||||
self.unroll_id = _AgentCollector._next_unroll_id
|
||||
_AgentCollector._next_unroll_id += 1
|
||||
|
||||
# Next obs -> obs.
|
||||
assert SampleBatch.OBS not in values
|
||||
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
|
||||
del values[SampleBatch.NEXT_OBS]
|
||||
# Make sure EPS_ID stays the same for this agent. Usually, it should
|
||||
# not be part of `values` anyways.
|
||||
|
||||
# Make sure EPS_ID/UNROLL_ID stay the same for this agent.
|
||||
if SampleBatch.EPS_ID in values:
|
||||
assert values[SampleBatch.EPS_ID] == self.episode_id
|
||||
del values[SampleBatch.EPS_ID]
|
||||
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
|
||||
if SampleBatch.UNROLL_ID in values:
|
||||
assert values[SampleBatch.UNROLL_ID] == self.unroll_id
|
||||
del values[SampleBatch.UNROLL_ID]
|
||||
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
|
||||
|
||||
for k, v in values.items():
|
||||
if k not in self.buffers:
|
||||
self._build_buffers(single_row=values)
|
||||
self.buffers[k].append(v)
|
||||
# Do not flatten infos, state_out_ and actions.
|
||||
# Infos/state-outs may be structs that change from timestep to
|
||||
# timestep. Actions - on the other hand - are already flattened
|
||||
# in the sampler.
|
||||
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS
|
||||
] or k.startswith("state_out_"):
|
||||
self.buffers[k][0].append(v)
|
||||
# Flatten all other columns.
|
||||
else:
|
||||
flattened = tree.flatten(v)
|
||||
for i, sub_list in enumerate(self.buffers[k]):
|
||||
sub_list.append(flattened[i])
|
||||
self.agent_steps += 1
|
||||
|
||||
def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
|
||||
|
@ -157,7 +212,9 @@ class _AgentCollector:
|
|||
# Keep an np-array cache so we don't have to regenerate the
|
||||
# np-array for different view_cols using to the same data_col.
|
||||
if data_col not in np_data:
|
||||
np_data[data_col] = to_float_np_array(self.buffers[data_col])
|
||||
np_data[data_col] = [
|
||||
to_float_np_array(d) for d in self.buffers[data_col]
|
||||
]
|
||||
|
||||
# Range of indices on time-axis, e.g. "-50:-1". Together with
|
||||
# the `batch_repeat_value`, this determines the data produced.
|
||||
|
@ -171,42 +228,50 @@ class _AgentCollector:
|
|||
# every n timesteps.
|
||||
if view_req.batch_repeat_value > 1:
|
||||
count = int(
|
||||
math.ceil((len(np_data[data_col]) - self.shift_before)
|
||||
/ view_req.batch_repeat_value))
|
||||
data = np.asarray([
|
||||
np_data[data_col][self.shift_before +
|
||||
(i * view_req.batch_repeat_value) +
|
||||
view_req.shift_from +
|
||||
obs_shift:self.shift_before +
|
||||
(i * view_req.batch_repeat_value) +
|
||||
view_req.shift_to + 1 + obs_shift]
|
||||
for i in range(count)
|
||||
])
|
||||
math.ceil(
|
||||
(len(np_data[data_col][0]) - self.shift_before) /
|
||||
view_req.batch_repeat_value))
|
||||
data = [
|
||||
np.asarray([
|
||||
d[self.shift_before +
|
||||
(i * view_req.batch_repeat_value) +
|
||||
view_req.shift_from +
|
||||
obs_shift:self.shift_before +
|
||||
(i * view_req.batch_repeat_value) +
|
||||
view_req.shift_to + 1 + obs_shift]
|
||||
for i in range(count)
|
||||
]) for d in np_data[data_col]
|
||||
]
|
||||
# Batch repeat value = 1: Repeat the shift_from/to range at
|
||||
# each timestep.
|
||||
else:
|
||||
d = np_data[data_col]
|
||||
d0 = np_data[data_col][0]
|
||||
shift_win = view_req.shift_to - view_req.shift_from + 1
|
||||
data_size = d.itemsize * int(np.product(d.shape[1:]))
|
||||
data_size = d0.itemsize * int(np.product(d0.shape[1:]))
|
||||
strides = [
|
||||
d.itemsize * int(np.product(d.shape[i + 1:]))
|
||||
for i in range(1, len(d.shape))
|
||||
d0.itemsize * int(np.product(d0.shape[i + 1:]))
|
||||
for i in range(1, len(d0.shape))
|
||||
]
|
||||
start = self.shift_before - shift_win + 1 + obs_shift + \
|
||||
view_req.shift_to
|
||||
data = np.lib.stride_tricks.as_strided(
|
||||
d[start:start + self.agent_steps],
|
||||
[self.agent_steps, shift_win
|
||||
] + [d.shape[i] for i in range(1, len(d.shape))],
|
||||
[data_size, data_size] + strides)
|
||||
data = [
|
||||
np.lib.stride_tricks.as_strided(
|
||||
d[start:start + self.agent_steps],
|
||||
[self.agent_steps, shift_win
|
||||
] + [d.shape[i] for i in range(1, len(d.shape))],
|
||||
[data_size, data_size] + strides)
|
||||
for d in np_data[data_col]
|
||||
]
|
||||
# Set of (probably non-consecutive) indices.
|
||||
# Example:
|
||||
# shift=[-3, 0]
|
||||
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
|
||||
# resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
|
||||
elif isinstance(view_req.shift, np.ndarray):
|
||||
data = np_data[data_col][self.shift_before + obs_shift +
|
||||
view_req.shift]
|
||||
data = [
|
||||
d[self.shift_before + obs_shift + view_req.shift]
|
||||
for d in np_data[data_col]
|
||||
]
|
||||
# Single shift int value. Use the trajectory as-is, and if
|
||||
# `shift` != 0: shifted by that value.
|
||||
else:
|
||||
|
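The branch above builds overlapping `shift_from:shift_to` windows over each flattened sub-component with `np.lib.stride_tricks.as_strided`. A toy illustration of that windowing trick on a single 1-D buffer (variable names mirror the diff; the values are made up):

    import numpy as np

    d = np.arange(10, dtype=np.float32)     # one already 0-padded per-timestep buffer
    shift_from, shift_to = -2, 0            # view each step together with its 2 predecessors
    shift_win = shift_to - shift_from + 1   # window length = 3

    item = d.itemsize                       # bytes per element
    windows = np.lib.stride_tricks.as_strided(
        d,
        shape=(len(d) - shift_win + 1, shift_win),
        strides=(item, item))               # consecutive windows start one element apart

    print(windows[:3])
    # [[0. 1. 2.]
    #  [1. 2. 3.]
    #  [2. 3. 4.]]
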
@ -215,58 +280,77 @@ class _AgentCollector:
|
|||
# Batch repeat (only provide a value every n timesteps).
|
||||
if view_req.batch_repeat_value > 1:
|
||||
count = int(
|
||||
math.ceil((len(np_data[data_col]) - self.shift_before)
|
||||
/ view_req.batch_repeat_value))
|
||||
data = np.asarray([
|
||||
np_data[data_col][self.shift_before + (
|
||||
i * view_req.batch_repeat_value) + shift]
|
||||
for i in range(count)
|
||||
])
|
||||
math.ceil(
|
||||
(len(np_data[data_col][0]) - self.shift_before) /
|
||||
view_req.batch_repeat_value))
|
||||
data = [
|
||||
np.asarray([
|
||||
d[self.shift_before +
|
||||
(i * view_req.batch_repeat_value) + shift]
|
||||
for i in range(count)
|
||||
]) for d in np_data[data_col]
|
||||
]
|
||||
# Shift is exactly 0: Use trajectory as is.
|
||||
elif shift == 0:
|
||||
data = np_data[data_col][self.shift_before:]
|
||||
data = [d[self.shift_before:] for d in np_data[data_col]]
|
||||
# Shift is positive: We still need to 0-pad at the end.
|
||||
elif shift > 0:
|
||||
data = to_float_np_array(
|
||||
self.buffers[data_col][self.shift_before + shift:] + [
|
||||
np.zeros(
|
||||
shape=view_req.space.shape,
|
||||
dtype=view_req.space.dtype)
|
||||
for _ in range(shift)
|
||||
])
|
||||
data = [
|
||||
to_float_np_array(
|
||||
np.concatenate([
|
||||
d[self.shift_before + shift:], [
|
||||
np.zeros(
|
||||
shape=view_req.space.shape,
|
||||
dtype=view_req.space.dtype)
|
||||
for _ in range(shift)
|
||||
]
|
||||
])) for d in np_data[data_col]
|
||||
]
|
||||
# Shift is negative: Shift into the already existing and
|
||||
# 0-padded "before" area of our buffers.
|
||||
else:
|
||||
data = np_data[data_col][self.shift_before + shift:shift]
|
||||
data = [
|
||||
d[self.shift_before + shift:shift]
|
||||
for d in np_data[data_col]
|
||||
]
|
||||
|
||||
if len(data) > 0:
|
||||
batch_data[view_col] = data
|
||||
if data_col not in self.buffer_structs:
|
||||
batch_data[view_col] = data[0]
|
||||
else:
|
||||
batch_data[view_col] = tree.unflatten_as(
|
||||
self.buffer_structs[data_col], data)
|
||||
|
||||
# Due to possible batch-repeats > 1, columns in the resulting batch
|
||||
# may not all have the same batch size.
|
||||
batch = SampleBatch(batch_data)
|
||||
|
||||
# Add EPS_ID and UNROLL_ID to batch.
|
||||
batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count)
|
||||
if SampleBatch.UNROLL_ID not in batch:
|
||||
# TODO: (sven) Once we have the additional
|
||||
# model.preprocess_train_batch in place (attention net PR), we
|
||||
# should not even need UNROLL_ID anymore:
|
||||
# Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
|
||||
batch[SampleBatch.UNROLL_ID] = np.repeat(
|
||||
_AgentCollector._next_unroll_id, batch.count)
|
||||
_AgentCollector._next_unroll_id += 1
|
||||
# Adjust the seq-lens array depending on the incoming agent sequences.
|
||||
if self.policy.is_recurrent():
|
||||
seq_lens = []
|
||||
max_seq_len = self.policy.config["model"]["max_seq_len"]
|
||||
count = batch.count
|
||||
while count > 0:
|
||||
seq_lens.append(min(count, max_seq_len))
|
||||
count -= max_seq_len
|
||||
batch["seq_lens"] = np.array(seq_lens)
|
||||
batch.max_seq_len = max_seq_len
|
||||
|
||||
# This trajectory is continuing -> Copy data at the end (in the size of
|
||||
# self.shift_before) to the beginning of buffers and erase everything
|
||||
# else.
|
||||
if not self.buffers[SampleBatch.DONES][-1]:
|
||||
if not self.buffers[SampleBatch.DONES][0][-1]:
|
||||
# Copy data to beginning of buffer and cut lists.
|
||||
if self.shift_before > 0:
|
||||
for k, data in self.buffers.items():
|
||||
self.buffers[k] = data[-self.shift_before:]
|
||||
# Loop through
|
||||
for i in range(len(data)):
|
||||
self.buffers[k][i] = data[i][-self.shift_before:]
|
||||
self.agent_steps = 0
|
||||
|
||||
# Reset our unroll_id.
|
||||
self.unroll_id = None
|
||||
|
||||
return batch
|
||||
|
||||
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
|
||||
|
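For recurrent policies, the hunk above chops an agent's trajectory into RNN sequences of at most `max_seq_len` steps. The chunking loop in isolation (a direct transcription of the while-loop, with made-up numbers):

    import numpy as np

    def chunk_seq_lens(count: int, max_seq_len: int) -> np.ndarray:
        # Split a trajectory of `count` timesteps into sequences of at
        # most `max_seq_len` (the last one may be shorter).
        seq_lens = []
        while count > 0:
            seq_lens.append(min(count, max_seq_len))
            count -= max_seq_len
        return np.array(seq_lens)

    print(chunk_seq_lens(25, 10))  # [10 10  5]
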
@ -279,12 +363,25 @@ class _AgentCollector:
|
|||
for col, data in single_row.items():
|
||||
if col in self.buffers:
|
||||
continue
|
||||
|
||||
shift = self.shift_before - (1 if col in [
|
||||
SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
|
||||
SampleBatch.ENV_ID, SampleBatch.T
|
||||
SampleBatch.ENV_ID, SampleBatch.T, SampleBatch.UNROLL_ID
|
||||
] else 0)
|
||||
# Python primitive, tensor, or dict (e.g. INFOs).
|
||||
self.buffers[col] = [data for _ in range(shift)]
|
||||
|
||||
# Store all data as flattened lists, except INFOS and state-out
|
||||
# lists. These are monolithic items (infos is a dict that
|
||||
# should not be further split, same for state-out items, which
|
||||
# could be custom dicts as well).
|
||||
if col in [SampleBatch.INFOS, SampleBatch.ACTIONS
|
||||
] or col.startswith("state_out_"):
|
||||
self.buffers[col] = [[data for _ in range(shift)]]
|
||||
else:
|
||||
self.buffers[col] = [[v for _ in range(shift)]
|
||||
for v in tree.flatten(data)]
|
||||
# Store an example data struct so we know, how to unflatten
|
||||
# each data col.
|
||||
self.buffer_structs[col] = data
|
||||
|
||||
|
||||
class _PolicyCollector:
|
||||
|
@ -302,15 +399,13 @@ class _PolicyCollector:
|
|||
policy (Policy): The policy object.
|
||||
"""
|
||||
|
||||
self.buffers: Dict[str, List] = collections.defaultdict(list)
|
||||
self.batches = []
|
||||
self.policy = policy
|
||||
# The total timestep count for all agents that use this policy.
|
||||
# NOTE: This is not an env-step count (across n agents). AgentA and
|
||||
# agentB, both using this policy, acting in the same episode and both
|
||||
# doing n steps would increase the count by 2*n.
|
||||
self.agent_steps = 0
|
||||
# Seq-lens list of already added agent batches.
|
||||
self.seq_lens = [] if policy.is_recurrent() else None
|
||||
|
||||
def add_postprocessed_batch_for_training(
|
||||
self, batch: SampleBatch,
|
||||
|
@ -325,22 +420,13 @@ class _PolicyCollector:
|
|||
view-column needs to be copied at all (not needed for
|
||||
training).
|
||||
"""
|
||||
for view_col, data in batch.items():
|
||||
# 1) If col is not in view_requirements, we must have a direct
|
||||
# child of the base Policy that doesn't do auto-view req creation.
|
||||
# 2) Col is in view-reqs and needed for training.
|
||||
view_req = view_requirements.get(view_col)
|
||||
if view_req is None or view_req.used_for_training:
|
||||
self.buffers[view_col].extend(data)
|
||||
# Add the agent's trajectory length to our count.
|
||||
self.agent_steps += batch.count
|
||||
# Adjust the seq-lens array depending on the incoming agent sequences.
|
||||
if self.seq_lens is not None:
|
||||
max_seq_len = self.policy.config["model"]["max_seq_len"]
|
||||
count = batch.count
|
||||
while count > 0:
|
||||
self.seq_lens.append(min(count, max_seq_len))
|
||||
count -= max_seq_len
|
||||
# And remove columns not needed for training.
|
||||
for view_col, view_req in view_requirements.items():
|
||||
if view_col in batch and not view_req.used_for_training:
|
||||
del batch[view_col]
|
||||
self.batches.append(batch)
|
||||
|
||||
def build(self):
|
||||
"""Builds a SampleBatch for this policy from the collected data.
|
||||
|
@ -352,13 +438,11 @@ class _PolicyCollector:
|
|||
this policy.
|
||||
"""
|
||||
# Create batch from our buffers.
|
||||
batch = SampleBatch(self.buffers, seq_lens=self.seq_lens)
|
||||
# Clear buffers for future samples.
|
||||
self.buffers.clear()
|
||||
batch = SampleBatch.concat_samples(self.batches)
|
||||
# Clear batches for future samples.
|
||||
self.batches = []
|
||||
# Reset agent steps to 0 and seq-lens to empty list.
|
||||
self.agent_steps = 0
|
||||
if self.seq_lens is not None:
|
||||
self.seq_lens = []
|
||||
return batch
|
||||
|
||||
|
||||
|
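`_PolicyCollector.build()` now simply concatenates the postprocessed per-agent batches it has stored, instead of maintaining flat column buffers. A minimal sketch of that concatenation step; the column names and sizes are made up:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # Two postprocessed agent trajectories collected for the same policy.
    batch_a = SampleBatch({"obs": np.zeros((3, 4)), "rewards": np.ones(3)})
    batch_b = SampleBatch({"obs": np.ones((2, 4)), "rewards": np.zeros(2)})

    # What the new build() effectively does with self.batches.
    train_batch = SampleBatch.concat_samples([batch_a, batch_b])
    print(train_batch.count)  # 5
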
@ -479,7 +563,7 @@ class SimpleListCollector(SampleCollector):
|
|||
# Add initial obs to Trajectory.
|
||||
assert agent_key not in self.agent_collectors
|
||||
# TODO: determine exact shift-before based on the view-req shifts.
|
||||
self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
|
||||
self.agent_collectors[agent_key] = _AgentCollector(view_reqs, policy)
|
||||
self.agent_collectors[agent_key].add_init_obs(
|
||||
episode_id=episode.episode_id,
|
||||
agent_index=episode._agent_index(agent_id),
|
||||
|
@ -537,7 +621,13 @@ class SimpleListCollector(SampleCollector):
|
|||
Dict[str, TensorType]:
|
||||
policy = self.policy_map[policy_id]
|
||||
keys = self.forward_pass_agent_keys[policy_id]
|
||||
buffers = {k: self.agent_collectors[k].buffers for k in keys}
|
||||
|
||||
buffers = {}
|
||||
for k in keys:
|
||||
collector = self.agent_collectors[k]
|
||||
buffers[k] = collector.buffers
|
||||
# Use one agent's buffer_structs (they should all be the same).
|
||||
buffer_structs = self.agent_collectors[keys[0]].buffer_structs
|
||||
|
||||
input_dict = {}
|
||||
for view_col, view_req in policy.view_requirements.items():
|
||||
|
@ -558,33 +648,57 @@ class SimpleListCollector(SampleCollector):
|
|||
# Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
|
||||
else:
|
||||
time_indices = view_req.shift + delta
|
||||
data_list = []
|
||||
# Loop through agents and add-up their data (batch).
|
||||
|
||||
# Loop through agents and add up their data (batch).
|
||||
data = None
|
||||
for k in keys:
|
||||
if data_col == SampleBatch.EPS_ID:
|
||||
data_list.append(self.agent_collectors[k].episode_id)
|
||||
else:
|
||||
if data_col not in buffers[k]:
|
||||
if view_req.data_col is not None:
|
||||
space = policy.view_requirements[
|
||||
view_req.data_col].space
|
||||
else:
|
||||
space = view_req.space
|
||||
fill_value = np.zeros_like(space.sample()) \
|
||||
if isinstance(space, Space) else space
|
||||
self.agent_collectors[k]._build_buffers({
|
||||
data_col: fill_value
|
||||
})
|
||||
if isinstance(time_indices, tuple):
|
||||
if time_indices[1] == -1:
|
||||
data_list.append(
|
||||
buffers[k][data_col][time_indices[0]:])
|
||||
else:
|
||||
data_list.append(buffers[k][data_col][time_indices[
|
||||
0]:time_indices[1] + 1])
|
||||
# Buffer for the data does not exist yet: Create dummy
|
||||
# (zero) data.
|
||||
if data_col not in buffers[k]:
|
||||
if view_req.data_col is not None:
|
||||
space = policy.view_requirements[
|
||||
view_req.data_col].space
|
||||
else:
|
||||
data_list.append(buffers[k][data_col][time_indices])
|
||||
input_dict[view_col] = np.array(data_list)
|
||||
space = view_req.space
|
||||
|
||||
if isinstance(space, Space):
|
||||
fill_value = get_dummy_batch_for_space(
|
||||
space,
|
||||
batch_size=0,
|
||||
)
|
||||
else:
|
||||
fill_value = space
|
||||
|
||||
self.agent_collectors[k]._build_buffers({
|
||||
data_col: fill_value
|
||||
})
|
||||
|
||||
if data is None:
|
||||
data = [[] for _ in range(len(buffers[keys[0]][data_col]))]
|
||||
|
||||
# `shift_from` and `shift_to` are defined: User wants a
|
||||
# view with some time-range.
|
||||
if isinstance(time_indices, tuple):
|
||||
# `shift_to` == -1: Until the end (including(!) the
|
||||
# last item).
|
||||
if time_indices[1] == -1:
|
||||
for d, b in zip(data, buffers[k][data_col]):
|
||||
d.append(b[time_indices[0]:])
|
||||
# `shift_to` != -1: "Normal" range.
|
||||
else:
|
||||
for d, b in zip(data, buffers[k][data_col]):
|
||||
d.append(b[time_indices[0]:time_indices[1] + 1])
|
||||
# Single index.
|
||||
else:
|
||||
for d, b in zip(data, buffers[k][data_col]):
|
||||
d.append(b[time_indices])
|
||||
|
||||
np_data = [np.array(d) for d in data]
|
||||
if data_col in buffer_structs:
|
||||
input_dict[view_col] = tree.unflatten_as(
|
||||
buffer_structs[data_col], np_data)
|
||||
else:
|
||||
input_dict[view_col] = np_data[0]
|
||||
|
||||
self._reset_inference_calls(policy_id)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
|
|||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
|
||||
from ray.rllib.models.preprocessors import Preprocessor
|
||||
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
|
@ -44,7 +44,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
|||
from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
|
||||
ModelConfigDict, ModelGradients, ModelWeights, \
|
||||
MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
|
||||
SampleBatchType, TrainerConfigDict
|
||||
SampleBatchType
|
||||
from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
enable_periodic_logging
|
||||
from ray.util.iter import ParallelIteratorWorker
|
||||
|
@ -168,7 +168,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
env_creator: Callable[[EnvContext], EnvType],
|
||||
validate_env: Optional[Callable[[EnvType, EnvContext],
|
||||
None]] = None,
|
||||
policy_spec: Union[type, Dict[PolicyID, PolicySpec]] = None,
|
||||
policy_spec: Optional[Union[type, Dict[PolicyID,
|
||||
PolicySpec]]] = None,
|
||||
policy_mapping_fn: Optional[Callable[
|
||||
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
|
||||
policies_to_train: Optional[List[PolicyID]] = None,
|
||||
|
@ -176,24 +177,24 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
rollout_fragment_length: int = 100,
|
||||
count_steps_by: str = "env_steps",
|
||||
batch_mode: str = "truncate_episodes",
|
||||
episode_horizon: int = None,
|
||||
preprocessor_pref: Optional[str] = "deepmind",
|
||||
episode_horizon: Optional[int] = None,
|
||||
preprocessor_pref: str = "deepmind",
|
||||
sample_async: bool = False,
|
||||
compress_observations: bool = False,
|
||||
num_envs: int = 1,
|
||||
observation_fn: "ObservationFunction" = None,
|
||||
observation_fn: Optional["ObservationFunction"] = None,
|
||||
observation_filter: str = "NoFilter",
|
||||
clip_rewards: Optional[Union[bool, float]] = None,
|
||||
normalize_actions: bool = True,
|
||||
clip_actions: bool = False,
|
||||
env_config: EnvConfigDict = None,
|
||||
model_config: ModelConfigDict = None,
|
||||
policy_config: TrainerConfigDict = None,
|
||||
env_config: Optional[EnvConfigDict] = None,
|
||||
model_config: Optional[ModelConfigDict] = None,
|
||||
policy_config: Optional[PartialTrainerConfigDict] = None,
|
||||
worker_index: int = 0,
|
||||
num_workers: int = 0,
|
||||
record_env: Union[bool, str] = False,
|
||||
log_dir: str = None,
|
||||
log_level: str = None,
|
||||
log_dir: Optional[str] = None,
|
||||
log_level: Optional[str] = None,
|
||||
callbacks: Type["DefaultCallbacks"] = None,
|
||||
input_creator: Callable[[
|
||||
IOContext
|
||||
|
@ -206,7 +207,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
soft_horizon: bool = False,
|
||||
no_done_at_end: bool = False,
|
||||
seed: int = None,
|
||||
extra_python_environs: dict = None,
|
||||
extra_python_environs: Optional[dict] = None,
|
||||
fake_sampler: bool = False,
|
||||
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
|
||||
gym.spaces.Space]]] = None,
|
||||
|
@ -258,10 +259,10 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
that when `num_envs > 1`, episode steps will be buffered
|
||||
until the episode completes, and hence batches may contain
|
||||
significant amounts of off-policy data.
|
||||
episode_horizon (int): Whether to stop episodes at this horizon.
|
||||
preprocessor_pref (Optional[str]): Whether to use no preprocessor
|
||||
(None), RLlib preprocessors ("rllib") or deepmind ("deepmind"),
|
||||
when applicable.
|
||||
episode_horizon: Horizon at which to stop episodes (even if the
|
||||
environment itself has not retured a "done" signal).
|
||||
preprocessor_pref (str): Whether to use RLlib preprocessors
|
||||
("rllib") or deepmind ("deepmind"), when applicable.
|
||||
sample_async (bool): Whether to compute samples asynchronously in
|
||||
the background, which improves throughput but can cause samples
|
||||
to be slightly off-policy.
|
||||
|
@ -284,9 +285,9 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
env_config (EnvConfigDict): Config to pass to the env creator.
|
||||
model_config (ModelConfigDict): Config to use when creating the
|
||||
policy model.
|
||||
policy_config (TrainerConfigDict): Config to pass to the policy.
|
||||
In the multi-agent case, this config will be merged with the
|
||||
per-policy configs specified by `policy_spec`.
|
||||
policy_config: Config to pass to the
|
||||
policy. In the multi-agent case, this config will be merged
|
||||
with the per-policy configs specified by `policy_spec`.
|
||||
worker_index (int): For remote workers, this should be set to a
|
||||
non-zero and unique value. This index is passed to created envs
|
||||
through EnvContext so that envs can be configured per worker.
|
||||
|
@ -378,7 +379,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
|
||||
ParallelIteratorWorker.__init__(self, gen_rollouts, False)
|
||||
|
||||
policy_config: TrainerConfigDict = policy_config or {}
|
||||
policy_config = policy_config or {}
|
||||
if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
|
||||
# This eager check is necessary for certain all-framework tests
|
||||
# that use tf's eager_mode() context generator.
|
||||
|
@ -400,7 +401,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
num_workers=num_workers,
|
||||
)
|
||||
self.env_context = env_context
|
||||
self.policy_config: TrainerConfigDict = policy_config
|
||||
self.policy_config: PartialTrainerConfigDict = policy_config
|
||||
if callbacks:
|
||||
self.callbacks: "DefaultCallbacks" = callbacks()
|
||||
else:
|
||||
|
@ -424,10 +425,10 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.batch_mode: str = batch_mode
|
||||
self.compress_observations: bool = compress_observations
|
||||
self.preprocessing_enabled: bool = False \
|
||||
if preprocessor_pref is None else True
|
||||
if policy_config.get("_disable_preprocessor_api") else True
|
||||
self.observation_filter = observation_filter
|
||||
self.last_batch: SampleBatchType = None
|
||||
self.global_vars: dict = None
|
||||
self.last_batch: Optional[SampleBatchType] = None
|
||||
self.global_vars: Optional[dict] = None
|
||||
self.fake_sampler: bool = fake_sampler
|
||||
|
||||
# Update the global seed for numpy/random/tf-eager/torch if we are not
|
||||
|
@ -1076,10 +1077,9 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
space of the policy to add.
|
||||
action_space (Optional[gym.spaces.Space]): The action space
|
||||
of the policy to add.
|
||||
config (Optional[PartialTrainerConfigDict]): The config
|
||||
overrides for the policy to add.
|
||||
policy_config (Optional[TrainerConfigDict]): The base config of the
|
||||
Trainer object owning this RolloutWorker.
|
||||
config: The config overrides for the policy to add.
|
||||
policy_config: The base config of the Trainer object owning this
|
||||
RolloutWorker.
|
||||
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
|
||||
PolicyID]]): An optional (updated) policy mapping function to
|
||||
use from here on. Note that already ongoing episodes will not
|
||||
|
@ -1340,7 +1340,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
def _build_policy_map(
|
||||
self,
|
||||
policy_dict: MultiAgentPolicyConfigDict,
|
||||
policy_config: TrainerConfigDict,
|
||||
policy_config: PartialTrainerConfigDict,
|
||||
session_creator: Optional[Callable[[], "tf1.Session"]] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
|
||||
|
@ -1371,13 +1371,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
if preprocessor is not None:
|
||||
obs_space = preprocessor.observation_space
|
||||
else:
|
||||
self.preprocessors[name] = NoPreprocessor(obs_space)
|
||||
|
||||
if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
|
||||
raise ValueError(
|
||||
"Found raw Tuple|Dict space as input to policy. "
|
||||
"Please preprocess these observations with a "
|
||||
"Tuple|DictFlatteningPreprocessor.")
|
||||
self.preprocessors[name] = None
|
||||
|
||||
self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
|
||||
conf, merged_conf)
|
||||
|
|
|
@ -275,14 +275,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
horizon: Optional[int] = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
normalize_actions: bool = True,
|
||||
clip_actions: bool = False,
|
||||
blackhole_outputs: bool = False,
|
||||
soft_horizon: bool = False,
|
||||
no_done_at_end: bool = False,
|
||||
observation_fn: "ObservationFunction" = None,
|
||||
observation_fn: Optional["ObservationFunction"] = None,
|
||||
sample_collector_class: Optional[Type[SampleCollector]] = None,
|
||||
render: bool = False,
|
||||
# Obsolete.
|
||||
|
@ -308,7 +308,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
Refers to the unit of `rollout_fragment_length`.
|
||||
callbacks (Callbacks): The Callbacks object to use when episode
|
||||
events happen during rollout.
|
||||
horizon (Optional[int]): Hard-reset the Env
|
||||
horizon: Hard-reset the Env after this many timesteps.
|
||||
multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be
|
||||
exactly `rollout_fragment_length` in size.
|
||||
|
@ -452,7 +452,7 @@ def _env_runner(
|
|||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
extra_batch_callback: Callable[[SampleBatchType], None],
|
||||
horizon: int,
|
||||
horizon: Optional[int],
|
||||
normalize_actions: bool,
|
||||
clip_actions: bool,
|
||||
multiple_episodes_in_batch: bool,
|
||||
|
@ -470,7 +470,7 @@ def _env_runner(
|
|||
worker (RolloutWorker): Reference to the current rollout worker.
|
||||
base_env (BaseEnv): Env implementing BaseEnv.
|
||||
extra_batch_callback (fn): function to send extra batch data to.
|
||||
horizon (int): Horizon of the episode.
|
||||
horizon: Horizon of the episode.
|
||||
multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be exactly
|
||||
`rollout_fragment_length` in size.
|
||||
|
|
|
@ -117,21 +117,24 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
ev.stop()
|
||||
|
||||
def test_batch_ids(self):
|
||||
fragment_len = 100
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_spec=MockPolicy,
|
||||
rollout_fragment_length=1)
|
||||
rollout_fragment_length=fragment_len)
|
||||
batch1 = ev.sample()
|
||||
batch2 = ev.sample()
|
||||
self.assertEqual(len(set(batch1["unroll_id"])), 1)
|
||||
self.assertEqual(len(set(batch2["unroll_id"])), 1)
|
||||
self.assertEqual(
|
||||
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
|
||||
unroll_ids_1 = set(batch1["unroll_id"])
|
||||
unroll_ids_2 = set(batch2["unroll_id"])
|
||||
# Assert no overlap of unroll IDs between sample() calls.
|
||||
self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1))
|
||||
# CartPole episodes should be short initially: Expect more than one
|
||||
# unroll ID in each batch.
|
||||
self.assertTrue(len(unroll_ids_1) > 1)
|
||||
self.assertTrue(len(unroll_ids_2) > 1)
|
||||
ev.stop()
|
||||
|
||||
def test_global_vars_update(self):
|
||||
# Allow for Unittest run.
|
||||
ray.init(num_cpus=5, ignore_reinit_error=True)
|
||||
for fw in framework_iterator(frameworks=("tf2", "tf")):
|
||||
agent = A2CTrainer(
|
||||
env="CartPole-v0",
|
||||
|
|
|
@ -185,31 +185,49 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
policy_mapping_fn=None,
|
||||
num_envs=1,
|
||||
)
|
||||
# Add the next action to the view reqs of the policy.
|
||||
# Add the next action (a') and 2nd next action (a'') to the view
|
||||
# requirements of the policy.
|
||||
# This should be visible then in postprocessing and train batches.
|
||||
# Switch off for action computations (can't be there as we don't know
|
||||
# the next action already at action computation time).
|
||||
# the next actions already at action computation time).
|
||||
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
|
||||
"next_actions"] = ViewRequirement(
|
||||
SampleBatch.ACTIONS,
|
||||
shift=1,
|
||||
space=action_space,
|
||||
used_for_compute_actions=False)
|
||||
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
|
||||
"2nd_next_actions"] = ViewRequirement(
|
||||
SampleBatch.ACTIONS,
|
||||
shift=2,
|
||||
space=action_space,
|
||||
used_for_compute_actions=False)
|
||||
|
||||
# Make sure, we have DONEs as well.
|
||||
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
|
||||
"dones"] = ViewRequirement()
|
||||
batch = rollout_worker_w_api.sample()
|
||||
self.assertTrue("next_actions" in batch)
|
||||
self.assertTrue("2nd_next_actions" in batch)
|
||||
expected_a_ = None # expected next action
|
||||
expected_a__ = None # expected 2nd next action
|
||||
for i in range(len(batch["actions"])):
|
||||
a, d, a_ = batch["actions"][i], batch["dones"][i], \
|
||||
batch["next_actions"][i]
|
||||
if not d and expected_a_ is not None:
|
||||
check(a, expected_a_)
|
||||
elif d:
|
||||
a, d, a_, a__ = \
|
||||
batch["actions"][i], batch["dones"][i], \
|
||||
batch["next_actions"][i], batch["2nd_next_actions"][i]
|
||||
# Episode done: next action and 2nd next action should be 0.
|
||||
if d:
|
||||
check(a_, 0)
|
||||
check(a__, 0)
|
||||
expected_a_ = None
|
||||
expected_a__ = None
|
||||
continue
|
||||
# Episode is not done and we have an expected next-a.
|
||||
if expected_a_ is not None:
|
||||
check(a, expected_a_)
|
||||
if expected_a__ is not None:
|
||||
check(a_, expected_a__)
|
||||
expected_a__ = a__
|
||||
expected_a_ = a_
|
||||
|
||||
def test_traj_view_lstm_functionality(self):
|
||||
|
|
|
@ -57,7 +57,7 @@ class MyCallbacks(DefaultCallbacks):
|
|||
env_index: int, **kwargs):
|
||||
# Make sure this episode is really done.
|
||||
assert episode.batch_builder.policy_collectors[
|
||||
"default_policy"].buffers["dones"][-1], \
|
||||
"default_policy"].batches[-1]["dones"][-1], \
|
||||
"ERROR: `on_episode_end()` should only be called " \
|
||||
"after episode is done!"
|
||||
pole_angle = np.mean(episode.user_data["pole_angles"])
|
||||
|
|
|
@ -49,6 +49,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
|||
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
|
||||
SampleBatch.REWARDS: ViewRequirement(),
|
||||
SampleBatch.DONES: ViewRequirement(),
|
||||
SampleBatch.UNROLL_ID: ViewRequirement(),
|
||||
},
|
||||
**self.model.view_requirements)
|
||||
|
||||
|
|
rllib/examples/preprocessing_disabled.py (new file, 107 lines)
@ -0,0 +1,107 @@
|
|||
"""
|
||||
Example for using _disable_preprocessor_api=True to disable all preprocessing.
|
||||
|
||||
This example shows:
|
||||
- How a complex observation space from the env is handled directly by the
|
||||
model.
|
||||
- Complex observations are flattened into lists of tensors and as such
|
||||
stored by the SampleCollectors.
|
||||
- This has the advantage that preprocessing happens in batched fashion
|
||||
(in the model).
|
||||
"""
|
||||
import argparse
|
||||
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
|
||||
def get_cli_args():
|
||||
"""Create CLI parser and return parsed arguments"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# general args
|
||||
parser.add_argument(
|
||||
"--run", default="PPO", help="The RLlib-registered algorithm to use.")
|
||||
parser.add_argument("--num-cpus", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--framework",
|
||||
choices=["tf", "tf2", "tfe", "torch"],
|
||||
default="tf",
|
||||
help="The DL framework specifier.")
|
||||
parser.add_argument(
|
||||
"--stop-iters",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Number of iterations to train.")
|
||||
parser.add_argument(
|
||||
"--stop-timesteps",
|
||||
type=int,
|
||||
default=500000,
|
||||
help="Number of timesteps to train.")
|
||||
parser.add_argument(
|
||||
"--stop-reward",
|
||||
type=float,
|
||||
default=80.0,
|
||||
help="Reward at which we stop training.")
|
||||
parser.add_argument(
|
||||
"--as-test",
|
||||
action="store_true",
|
||||
help="Whether this script should be run as a test: --stop-reward must "
|
||||
"be achieved within --stop-timesteps AND --stop-iters.")
|
||||
parser.add_argument(
|
||||
"--no-tune",
|
||||
action="store_true",
|
||||
help="Run without Tune using a manual train loop instead. Here,"
|
||||
"there is no TensorBoard support.")
|
||||
parser.add_argument(
|
||||
"--local-mode",
|
||||
action="store_true",
|
||||
help="Init Ray in local mode for easier debugging.")
|
||||
|
||||
args = parser.parse_args()
|
||||
print(f"Running with following CLI args: {args}")
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_cli_args()
|
||||
|
||||
ray.init(local_mode=args.local_mode)
|
||||
|
||||
config = {
|
||||
"env": "ray.rllib.examples.env.random_env.RandomEnv",
|
||||
"env_config": {
|
||||
"config": {
|
||||
"observation_space": Dict({
|
||||
"a": Discrete(2),
|
||||
"b": Dict({
|
||||
"ba": Discrete(3),
|
||||
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
|
||||
}),
|
||||
"c": Tuple((MultiDiscrete([2, 3]), Discrete(2))),
|
||||
"d": Box(-1.0, 1.0, (2, ), dtype=np.int32),
|
||||
}),
|
||||
},
|
||||
},
|
||||
# Set this to True to enforce no preprocessors being used.
|
||||
# Complex observations now arrive directly in the model as
|
||||
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
|
||||
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
|
||||
"_disable_preprocessor_api": True,
|
||||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
|
||||
"framework": args.framework,
|
||||
}
|
||||
|
||||
stop = {
|
||||
"training_iteration": args.stop_iters,
|
||||
"timesteps_total": args.stop_timesteps,
|
||||
"episode_reward_mean": args.stop_reward,
|
||||
}
|
||||
|
||||
results = tune.run(args.run, config=config, stop=stop, verbose=2)
|
||||
|
||||
ray.shutdown()
|
|
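The two BUILD targets added earlier run this new example as a short smoke test; locally, the rough equivalent would be "python rllib/examples/preprocessing_disabled.py --framework=torch --stop-iters=2" (drop --framework for the default tf run).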
@@ -502,7 +502,8 @@ class LocalReplayBuffer(ParallelIteratorWorker):
                     # If SampleBatch has prio-replay weights, average
                     # over these to use as a weight for the entire
                     # sequence.
-                    if "weights" in time_slice:
+                    if "weights" in time_slice and \
+                            len(time_slice["weights"]):
                         weight = np.mean(time_slice["weights"])
                     else:
                         weight = None

@@ -46,9 +46,9 @@ MODEL_DEFAULTS: ModelConfigDict = {
     "_use_default_native_models": False,
     # Experimental flag.
     # If True, user specified no preprocessor to be created
-    # (via config.preprocessor_pref=None). If True, observations will arrive
-    # in model as they are returned by the env.
-    "_no_preprocessing": False,
+    # (via config._disable_preprocessor_api=True). If True, observations
+    # will arrive in model as they are returned by the env.
+    "_disable_preprocessor_api": False,

     # === Built-in options ===
     # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py

@@ -217,12 +217,12 @@ class ModelV2:
        else:
            restored = input_dict.copy()

-       # No Preprocessor used: `config.preprocessor_pref`=None.
+       # No Preprocessor used: `config._disable_preprocessor_api`=True.
        # TODO: This is unnecessary for when no preprocessor is used.
        #  Obs are not flat then anymore. However, we'll keep this
        #  here for backward-compatibility until Preprocessors have
        #  been fully deprecated.
-       if self.model_config.get("_no_preprocessing"):
+       if self.model_config.get("_disable_preprocessor_api"):
            restored["obs_flat"] = input_dict["obs"]
        # Input to this Model went through a Preprocessor.
        # Generate extra keys: "obs_flat" (vs "obs", which will hold the
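With the preprocessor API disabled, `input_dict["obs"]` arrives in the model as a (possibly nested) structure of batched tensors, and `obs_flat` above is just an alias for it. A rough sketch, in plain numpy, of what a custom model could then do itself if it wants one flat vector per sample (shapes and values are made up; real models receive tf/torch tensors):

    import numpy as np
    import tree  # pip install dm_tree

    # A batch of 3 structured observations as the model would receive them.
    obs_batch = {
        "a": np.array([0, 1, 1]),                        # Discrete component
        "b": (np.random.rand(3, 2).astype(np.float32),   # Tuple component 0
              np.array([2, 0, 1])),                      # Tuple component 1
    }

    # Cast every leaf to float, reshape to (batch, -1), then concatenate.
    parts = [c.astype(np.float32).reshape(len(c), -1) for c in tree.flatten(obs_batch)]
    flat_obs = np.concatenate(parts, axis=-1)
    print(flat_obs.shape)  # (3, 4)
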
@@ -1,7 +1,7 @@
 from typing import List

 from ray.rllib.utils.annotations import PublicAPI
-from ray.rllib.utils.framework import TensorType, TensorStructType
+from ray.rllib.utils.typing import TensorType, TensorStructType


 @PublicAPI

@ -3,14 +3,57 @@ from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
|
|||
import numpy as np
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \
|
||||
get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \
|
||||
OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor
|
||||
from ray.rllib.utils.test_utils import check
|
||||
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
|
||||
framework_iterator
|
||||
|
||||
|
||||
class TestPreprocessors(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_preprocessing_disabled(self):
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
|
||||
config["env"] = "ray.rllib.examples.env.random_env.RandomEnv"
|
||||
config["env_config"] = {
|
||||
"config": {
|
||||
"observation_space": Dict({
|
||||
"a": Discrete(5),
|
||||
"b": Dict({
|
||||
"ba": Discrete(4),
|
||||
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
|
||||
}),
|
||||
"c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
|
||||
"d": Box(-1.0, 1.0, (1, ), dtype=np.int32),
|
||||
}),
|
||||
},
|
||||
}
|
||||
# Set this to True to enforce no preprocessors being used.
|
||||
# Complex observations now arrive directly in the model as
|
||||
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
|
||||
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
|
||||
config["_disable_preprocessor_api"] = True
|
||||
|
||||
num_iterations = 1
|
||||
# Only supported for tf so far.
|
||||
for _ in framework_iterator(config):
|
||||
trainer = ppo.PPOTrainer(config=config)
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
|
||||
def test_gym_preprocessors(self):
|
||||
p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
|
||||
self.assertEqual(type(p1), NoPreprocessor)
|
||||
|
|
|
@ -133,7 +133,7 @@ class ComplexInputNetwork(TFModelV2):
|
|||
cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
|
||||
outs.append(cnn_out)
|
||||
elif i in self.one_hot:
|
||||
if component.dtype in [tf.int32, tf.int64, tf.uint8]:
|
||||
if "int" in component.dtype.name:
|
||||
outs.append(
|
||||
one_hot(component, self.flattened_input_space[i]))
|
||||
else:
|
||||
|
|
|
@ -546,6 +546,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
# Skip action dist inputs placeholder (do later).
|
||||
elif view_col == SampleBatch.ACTION_DIST_INPUTS:
|
||||
continue
|
||||
# This is a tower, input placeholders already exist.
|
||||
elif view_col in existing_inputs:
|
||||
input_dict[view_col] = existing_inputs[view_col]
|
||||
# All others.
|
||||
|
@ -554,10 +555,15 @@ class DynamicTFPolicy(TFPolicy):
|
|||
if view_req.used_for_training:
|
||||
# Create a +time-axis placeholder if the shift is not an
|
||||
# int (range or list of ints).
|
||||
flatten = view_col not in [
|
||||
SampleBatch.OBS, SampleBatch.NEXT_OBS] or \
|
||||
not self.config["_disable_preprocessor_api"]
|
||||
input_dict[view_col] = get_placeholder(
|
||||
space=view_req.space,
|
||||
name=view_col,
|
||||
time_axis=time_axis)
|
||||
time_axis=time_axis,
|
||||
flatten=flatten,
|
||||
)
|
||||
dummy_batch = self._get_dummy_batch_from_view_requirements(
|
||||
batch_size=32)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ It supports both traced and non-traced eager execution modes."""
|
|||
import functools
|
||||
import logging
|
||||
import threading
|
||||
import tree # pip install dm_tree
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from ray.util.debug import log_once
|
||||
|
@ -425,10 +426,12 @@ def build_eager_tf_policy(
|
|||
if not tf1.executing_eagerly():
|
||||
tf1.enable_eager_execution()
|
||||
|
||||
input_dict = {
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
|
||||
"is_training": tf.constant(False),
|
||||
}
|
||||
input_dict = SampleBatch(
|
||||
{
|
||||
SampleBatch.CUR_OBS: tree.map_structure(
|
||||
lambda s: tf.convert_to_tensor(s), obs_batch),
|
||||
},
|
||||
_is_training=tf.constant(False))
|
||||
if prev_action_batch is not None:
|
||||
input_dict[SampleBatch.PREV_ACTIONS] = \
|
||||
tf.convert_to_tensor(prev_action_batch)
|
||||
|
@ -478,7 +481,7 @@ def build_eager_tf_policy(
|
|||
self._is_training = False
|
||||
self._state_in = state_batches or []
|
||||
# Calculate RNN sequence lengths.
|
||||
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
|
||||
batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
|
||||
seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
|
||||
else None
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import gym
|
|||
from gym.spaces import Box
|
||||
import logging
|
||||
import numpy as np
|
||||
import tree # pip install dm_tree
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
@ -17,7 +18,7 @@ from ray.rllib.utils.spaces.space_utils import clip_action, \
|
|||
    get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \
    unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
    TensorType, TrainerConfigDict, Tuple, Union
    TensorType, TensorStructType, TrainerConfigDict, Tuple, Union

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

@ -132,10 +133,12 @@ class Policy(metaclass=ABCMeta):
    @DeveloperAPI
    def compute_actions(
            self,
            obs_batch: Union[List[TensorType], TensorType],
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorType], TensorType] = None,
            prev_reward_batch: Union[List[TensorType], TensorType] = None,
            prev_action_batch: Union[List[TensorStructType],
                                     TensorStructType] = None,
            prev_reward_batch: Union[List[TensorStructType],
                                     TensorStructType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            explore: Optional[bool] = None,

@ -145,14 +148,10 @@ class Policy(metaclass=ABCMeta):
        """Computes actions for the current policy.

        Args:
            obs_batch (Union[List[TensorType], TensorType]): Batch of
                observations.
            state_batches (Optional[List[TensorType]]): List of RNN state input
                batches, if any.
            prev_action_batch (Union[List[TensorType], TensorType]): Batch of
                previous action values.
            prev_reward_batch (Union[List[TensorType], TensorType]): Batch of
                previous rewards.
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            info_batch (Optional[Dict[str, list]]): Batch of info objects.
            episodes (Optional[List[MultiAgentEpisode]]): List of
                MultiAgentEpisode, one for each obs in obs_batch. This provides

@ -181,10 +180,10 @@ class Policy(metaclass=ABCMeta):
    @DeveloperAPI
    def compute_single_action(
            self,
            obs: TensorType,
            obs: TensorStructType,
            state: Optional[List[TensorType]] = None,
            prev_action: Optional[TensorType] = None,
            prev_reward: Optional[TensorType] = None,
            prev_action: Optional[TensorStructType] = None,
            prev_reward: Optional[TensorStructType] = None,
            info: dict = None,
            episode: Optional["MultiAgentEpisode"] = None,
            clip_actions: bool = None,

@ -192,37 +191,34 @@ class Policy(metaclass=ABCMeta):
            timestep: Optional[int] = None,
            unsquash_actions: bool = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Unbatched version of compute_actions.

        Args:
            obs (TensorType): Single observation.
            state (Optional[List[TensorType]]): List of RNN state inputs, if
                any.
            prev_action (Optional[TensorType]): Previous action value, if any.
            prev_reward (Optional[TensorType]): Previous reward, if any.
            obs: Single observation.
            state: List of RNN state inputs, if any.
            prev_action: Previous action value, if any.
            prev_reward: Previous reward, if any.
            info (dict): Info object, if any.
            episode (Optional[MultiAgentEpisode]): this provides access to all
            episode: this provides access to all
                of the internal episode state, which may be useful for
                model-based or multi-agent algorithms.
            unsquash_actions (bool): Should actions be unsquashed according to
            unsquash_actions: Should actions be unsquashed according to
                the Policy's action space?
            clip_actions (bool): Should actions be clipped according to the
            clip_actions: Should actions be clipped according to the
                Policy's action space?
            explore (Optional[bool]): Whether to pick an exploitation or
            explore: Whether to pick an exploitation or
                exploration action
                (default: None -> use self.config["explore"]).
            timestep (Optional[int]): The current (sampling) time step.
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility.

        Returns:
            Tuple:
                - actions (TensorType): Single action.
                - state_outs (List[TensorType]): List of RNN state outputs,
                    if any.
                - info (dict): Dictionary of extra features, if any.
                - actions: Single action.
                - state_outs: List of RNN state outputs, if any.
                - info: Dictionary of extra features, if any.
        """
        # If policy works in normalized space, we should unsquash the action.
        # Use value of config.normalize_actions, if None.

@ -253,7 +249,7 @@ class Policy(metaclass=ABCMeta):
        ]

        out = self.compute_actions(
            [obs],
            tree.map_structure(lambda s: np.array([s]), obs),
            state_batch,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
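With preprocessors disabled, the observation handed to compute_single_action can itself be a dict or tuple, which is why the batching above switches from `[obs]` to a per-leaf `tree.map_structure` call. A minimal standalone sketch (the Dict-style observation below is made up for illustration, not taken from this diff):

import numpy as np
import tree  # pip install dm_tree

# A made-up complex observation (an image leaf plus a tuple leaf).
obs = {
    "camera": np.zeros((4, 4, 3), dtype=np.float32),
    "state": (np.array([0.5, -0.2], dtype=np.float32), np.array(3)),
}

# Old code: `[obs]` only batches flat, already array-like observations.
# New code: add a leading batch dimension to every leaf of the struct.
obs_batch = tree.map_structure(lambda s: np.array([s]), obs)

assert obs_batch["camera"].shape == (1, 4, 4, 3)
assert obs_batch["state"][0].shape == (1, 2)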
@ -805,7 +801,8 @@ class Policy(metaclass=ABCMeta):
        for key in all_accessed_keys:
            if key not in self.view_requirements and \
                    key != SampleBatch.SEQ_LENS:
                self.view_requirements[key] = ViewRequirement()
                self.view_requirements[key] = ViewRequirement(
                    used_for_compute_actions=False)
        if self._loss:
            # Tag those only needed for post-processing (with some
            # exceptions).

@ -852,7 +849,7 @@ class Policy(metaclass=ABCMeta):
        """
        ret = {}
        for view_col, view_req in self.view_requirements.items():
            if self.config["preprocessor_pref"] is not None and \
            if not self.config["_disable_preprocessor_api"] and \
                    isinstance(view_req.space,
                               (gym.spaces.Dict, gym.spaces.Tuple)):
                _, shape = ModelCatalog.get_action_shape(
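The `_disable_preprocessor_api` key checked above is the experimental top-level config flag this change set revolves around. A hedged usage sketch; the trainer and env choice are arbitrary placeholders, not taken from this diff:

from ray import tune

config = {
    "env": "CartPole-v0",  # placeholder env
    # Feed raw (possibly Dict/Tuple) observations to the model instead of
    # the flattened arrays produced by RLlib's preprocessors.
    "_disable_preprocessor_api": True,
    "framework": "tf",
}
tune.run("PPO", config=config, stop={"training_iteration": 2})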
@ -364,6 +364,9 @@ class SampleBatch(dict):
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    curr[p] = value[permutation]
                # Translate into list (tuples are immutable).
                if isinstance(curr[p], tuple):
                    curr[p] = list(curr[p])
                curr = curr[p]

        tree.map_structure_with_path(_permutate_in_place, self)
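The tuple-to-list conversion added above exists because shuffled leaves are written back into their parent container, and tuples reject item assignment. A toy, non-RLlib illustration:

import numpy as np

batch = {"obs": (np.arange(4), np.arange(4) * 10)}
permutation = np.array([3, 1, 0, 2])

parent = batch["obs"]
if isinstance(parent, tuple):
    # Tuples are immutable, so translate into a list first.
    parent = list(parent)
    batch["obs"] = parent
for i, leaf in enumerate(parent):
    parent[i] = leaf[permutation]

print(batch["obs"][0])  # -> [3 1 0 2]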
@ -4,6 +4,7 @@ import logging
import math
import numpy as np
import os
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import ray

@ -435,15 +436,16 @@ class TFPolicy(Policy):
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
            else obs_batch.shape[0]
        self.global_timestep += \
            len(obs_batch) if isinstance(obs_batch, list) \
            else tree.flatten(obs_batch)[0].shape[0]

        return fetched

    @override(Policy)
    def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            input_dict: Union[SampleBatch, Dict[str, TensorType]],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,

@ -464,6 +466,7 @@ class TFPolicy(Policy):

        # Update our global timestep by the batch size.
        self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
            else len(input_dict) if isinstance(input_dict, SampleBatch) \
            else obs_batch.shape[0]

        return fetched

@ -965,7 +968,11 @@ class TFPolicy(Policy):
        if hasattr(self, "_input_dict"):
            for key, value in input_dict.items():
                if key in self._input_dict:
                    builder.add_feed_dict({self._input_dict[key]: value})
                    # Handle complex/nested spaces as well.
                    tree.map_structure(
                        lambda k, v: builder.add_feed_dict({k: v}),
                        self._input_dict[key], value,
                    )
        # For policies that inherit directly from TFPolicy.
        else:
            builder.add_feed_dict({

@ -1004,7 +1011,10 @@ class TFPolicy(Policy):
                "Must pass in RNN state batches for placeholders {}, "
                "got {}".format(self._state_inputs, state_batches))

        builder.add_feed_dict({self._obs_input: obs_batch})
        tree.map_structure(
            lambda k, v: builder.add_feed_dict({k: v}),
            self._obs_input, obs_batch,
        )
        if state_batches:
            builder.add_feed_dict({
                self._seq_lens: np.ones(len(obs_batch))

@ -1106,8 +1116,12 @@ class TFPolicy(Policy):

        # Build the feed dict from the batch.
        feed_dict = {}
        for key, placeholder in self._loss_input_dict.items():
            feed_dict[placeholder] = train_batch[key]
        for key, placeholders in self._loss_input_dict.items():
            tree.map_structure(
                lambda ph, v: feed_dict.__setitem__(ph, v),
                placeholders,
                train_batch[key],
            )

        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
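All three TFPolicy feed-dict sites above now zip a nested placeholder structure against an equally nested value structure, leaf by leaf. A standalone sketch of that pattern (placeholder names and shapes are invented for illustration):

import numpy as np
import tree  # pip install dm_tree
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Invented nested placeholders and a matching value batch.
obs_ph = {
    "camera": tf1.placeholder(tf.float32, (None, 4, 4, 3), name="obs.camera"),
    "state": tf1.placeholder(tf.float32, (None, 2), name="obs.state"),
}
obs_batch = {
    "camera": np.zeros((1, 4, 4, 3), np.float32),
    "state": np.zeros((1, 2), np.float32),
}

# One feed-dict entry per (placeholder, value) leaf pair.
feed_dict = {}
tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v),
                   obs_ph, obs_batch)
assert len(feed_dict) == 2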
@ -26,7 +26,7 @@ from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor
from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
    TrainerConfigDict
    TensorStructType, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.evaluation import MultiAgentEpisode  # noqa

@ -246,23 +246,24 @@ class TorchPolicy(Policy):
    @DeveloperAPI
    def compute_actions(
            self,
            obs_batch: Union[List[TensorType], TensorType],
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorType], TensorType] = None,
            prev_reward_batch: Union[List[TensorType], TensorType] = None,
            prev_action_batch: Union[List[TensorStructType],
                                     TensorStructType] = None,
            prev_reward_batch: Union[List[TensorStructType],
                                     TensorStructType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                SampleBatch({
                    SampleBatch.CUR_OBS: np.asarray(obs_batch),
                }))
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    np.asarray(prev_action_batch)

@ -405,13 +406,13 @@ class TorchPolicy(Policy):
    @DeveloperAPI
    def compute_log_likelihoods(
            self,
            actions: Union[List[TensorType], TensorType],
            obs_batch: Union[List[TensorType], TensorType],
            actions: Union[List[TensorStructType], TensorStructType],
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorStructType],
                                              TensorStructType]] = None,
            prev_reward_batch: Optional[Union[List[TensorStructType],
                                              TensorStructType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:
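On the torch side, the lazy tensor dict above now receives the raw, possibly nested observation batch instead of a pre-flattened `np.asarray` result. Illustrative only (plain torch + dm_tree, not TorchPolicy internals):

import numpy as np
import torch
import tree  # pip install dm_tree

obs_batch = {
    "camera": np.zeros((2, 4, 4, 3), dtype=np.float32),
    "state": np.zeros((2, 2), dtype=np.float32),
}
# Convert every leaf of the nested batch to a torch tensor.
torch_obs = tree.map_structure(torch.as_tensor, obs_batch)
assert torch_obs["camera"].shape == (2, 4, 4, 3)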
@ -35,10 +35,11 @@ class Filter:
class NoFilter(Filter):
    is_concurrent = True

    def __init__(self, *args):
        pass

    def __call__(self, x, update=True):
        # Process no further if already np.ndarray, dict, or tuple.
        if isinstance(x, (np.ndarray, dict, tuple)):
            return x

        try:
            return np.asarray(x)
        except Exception:
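Given the early return added above, NoFilter passes dict and tuple observations through untouched rather than forcing them through `np.asarray`. A small sketch (assumes a Ray installation that already contains this change):

import numpy as np
from ray.rllib.utils.filter import NoFilter

f = NoFilter()
obs = {"camera": np.zeros((4, 4, 3)), "state": (1.0, 2.0)}
assert f(obs) is obs  # complex observations are returned as-is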
@ -5,16 +5,10 @@ import sys
from typing import Any, Optional

from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType
from ray.rllib.utils.typing import TensorShape, TensorType

logger = logging.getLogger(__name__)

# Represents a generic tensor type.
TensorType = TensorType

# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = TensorStructType


def try_import_jax(error=False):
    """Tries importing JAX and FLAX and returns both modules (or Nones).

@ -199,36 +193,36 @@ def _torch_stubs():
    return None, nn


def get_variable(value,
def get_variable(value: Any,
                 framework: str = "tf",
                 trainable: bool = False,
                 tf_name: str = "unnamed-variable",
                 torch_tensor: bool = False,
                 device: Optional[str] = None,
                 shape: Optional[TensorShape] = None,
                 dtype: Optional[Any] = None):
                 dtype: Optional[TensorType] = None) -> Any:
    """
    Args:
        value (any): The initial value to use. In the non-tf case, this will
        value: The initial value to use. In the non-tf case, this will
            be returned as is. In the tf case, this could be a tf-Initializer
            object.
        framework (str): One of "tf", "torch", or None.
        trainable (bool): Whether the generated variable should be
        framework: One of "tf", "torch", or None.
        trainable: Whether the generated variable should be
            trainable (tf)/require_grad (torch) or not (default: False).
        tf_name (str): For framework="tf": An optional name for the
        tf_name: For framework="tf": An optional name for the
            tf.Variable.
        torch_tensor (bool): For framework="torch": Whether to actually create
        torch_tensor: For framework="torch": Whether to actually create
            a torch.tensor, or just a python value (default).
        device (Optional[torch.Device]): An optional torch device to use for
        device: An optional torch device to use for
            the created torch tensor.
        shape (Optional[TensorShape]): An optional shape to use iff `value`
        shape: An optional shape to use iff `value`
            does not have any (e.g. if it's an initializer w/o explicit value).
        dtype (Optional[TensorType]): An optional dtype to use iff `value` does
        dtype: An optional dtype to use iff `value` does
            not have any (e.g. if it's an initializer w/o explicit value).
            This should always be a numpy dtype (e.g. np.float32, np.int64).

    Returns:
        any: A framework-specific variable (tf.Variable, torch.tensor, or
        A framework-specific variable (tf.Variable, torch.tensor, or
            python primitive).
    """
    if framework in ["tf2", "tf", "tfe"]:

@ -4,6 +4,7 @@ import numpy as np
import tree  # pip install dm_tree

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorStructType, TensorType

tf1, tf, tfv = try_import_tf()

@ -56,12 +57,26 @@ def get_gpu_devices():
    return [d.name for d in devices if "GPU" in d.device_type]


def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
def get_placeholder(*,
                    space=None,
                    value=None,
                    name=None,
                    time_axis=False,
                    flatten=True):
    from ray.rllib.models.catalog import ModelCatalog

    if space is not None:
        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
            return ModelCatalog.get_action_placeholder(space, None)
            if flatten:
                return ModelCatalog.get_action_placeholder(space, None)
            else:
                return tree.map_structure_with_path(
                    lambda path, component: get_placeholder(
                        space=component,
                        name=name + "." + ".".join([str(p) for p in path]),
                    ),
                    get_base_struct_from_space(space),
                )
        return tf1.placeholder(
            shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
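The new `flatten=False` branch above builds one placeholder per leaf of a nested space, named by its path. A standalone approximation of that behavior (the `obs.` name prefix and the plain-dict view of the space are assumptions; RLlib itself goes through `get_base_struct_from_space`):

import numpy as np
import tree  # pip install dm_tree
import gym
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

space = gym.spaces.Dict({
    "camera": gym.spaces.Box(-1.0, 1.0, (4, 4, 3)),
    "state": gym.spaces.Box(-1.0, 1.0, (2, )),
})
# Plain-dict view of the space (stand-in for get_base_struct_from_space).
struct = {k: s for k, s in space.spaces.items()}

placeholders = tree.map_structure_with_path(
    lambda path, s: tf1.placeholder(
        shape=(None, ) + s.shape,
        dtype=tf.float32 if s.dtype == np.float64 else s.dtype,
        name="obs." + ".".join(str(p) for p in path),
    ),
    struct,
)
print(placeholders["camera"].shape)  # (None, 4, 4, 3)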
@ -138,10 +153,13 @@ def zero_logps_from_actions(actions: TensorStructType) -> TensorType:

def one_hot(x, space):
    if isinstance(space, Discrete):
        return tf.one_hot(x, space.n)
        return tf.one_hot(x, space.n, dtype=tf.float32)
    elif isinstance(space, MultiDiscrete):
        return tf.concat(
            [tf.one_hot(x[:, i], n) for i, n in enumerate(space.nvec)],
            [
                tf.one_hot(x[:, i], n, dtype=tf.float32)
                for i, n in enumerate(space.nvec)
            ],
            axis=-1)
    else:
        raise ValueError("Unsupported space for `one_hot`: {}".format(space))
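A quick eager-mode check of the MultiDiscrete branch above: each sub-dimension is one-hot encoded separately and the results are concatenated (sample values are arbitrary):

import numpy as np
import tensorflow as tf
from gym.spaces import MultiDiscrete

space = MultiDiscrete([3, 4])
x = np.array([[2, 1], [0, 3]])
out = tf.concat(
    [
        tf.one_hot(x[:, i], n, dtype=tf.float32)
        for i, n in enumerate(space.nvec)
    ],
    axis=-1)
assert out.shape == (2, 7)  # 3 + 4 one-hot slots per row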
@ -200,11 +218,12 @@ def make_tf_callable(session_or_none, dynamic_shape=False):

    def make_wrapper(fn):
        # Static-graph mode: Create placeholders and make a session call each
        # time the wrapped function is called. Return this session call's
        # outputs.
        # time the wrapped function is called. Returns the output of this
        # session call.
        if session_or_none is not None:
            args_placeholders = []
            kwargs_placeholders = {}

            symbolic_out = [None]

            def call(*args, **kwargs):

@ -215,40 +234,42 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
                    else:
                        args_flat.append(a)
                args = args_flat

                # We have not built any placeholders yet: Do this once here,
                # then reuse the same placeholders each time we call this
                # function again.
                if symbolic_out[0] is None:
                    with session_or_none.graph.as_default():
                        for i, v in enumerate(args):

                        def _create_placeholders(path, value):
                            if dynamic_shape:
                                if len(v.shape) > 0:
                                    shape = (None, ) + v.shape[1:]
                                if len(value.shape) > 0:
                                    shape = (None, ) + value.shape[1:]
                                else:
                                    shape = ()
                            else:
                                shape = v.shape
                            args_placeholders.append(
                                tf1.placeholder(
                                    dtype=v.dtype,
                                    shape=shape,
                                    name="arg_{}".format(i)))
                        for k, v in kwargs.items():
                            if dynamic_shape:
                                if len(v.shape) > 0:
                                    shape = (None, ) + v.shape[1:]
                                else:
                                    shape = ()
                            else:
                                shape = v.shape
                            kwargs_placeholders[k] = \
                                tf1.placeholder(
                                    dtype=v.dtype,
                                    shape=shape,
                                    name="kwarg_{}".format(k))
                                shape = value.shape
                            return tf1.placeholder(
                                dtype=value.dtype,
                                shape=shape,
                                name=".".join([str(p) for p in path]),
                            )

                        placeholders = tree.map_structure_with_path(
                            _create_placeholders, args)
                        for ph in tree.flatten(placeholders):
                            args_placeholders.append(ph)

                        placeholders = tree.map_structure_with_path(
                            _create_placeholders, kwargs)
                        for k, ph in placeholders.items():
                            kwargs_placeholders[k] = ph

                        symbolic_out[0] = fn(*args_placeholders,
                                             **kwargs_placeholders)
                feed_dict = dict(zip(args_placeholders, args))
                feed_dict.update(
                    {kwargs_placeholders[k]: kwargs[k]
                     for k in kwargs.keys()})
                feed_dict = dict(zip(args_placeholders, tree.flatten(args)))
                tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v),
                                   kwargs_placeholders, kwargs)
                ret = session_or_none.run(symbolic_out[0], feed_dict)
                return ret
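The rewritten `make_tf_callable` relies on `tree.flatten` visiting leaves in a deterministic order, so flattened placeholders and flattened argument values can be zipped into one feed dict. A small standalone check of that assumption:

import numpy as np
import tree  # pip install dm_tree

args = ({"a": np.zeros(2), "b": np.ones(3)}, np.full(4, 2.0))
# Name every leaf by its path, mirroring the placeholder naming above.
names = tree.map_structure_with_path(
    lambda path, v: ".".join(str(p) for p in path), args)

assert tree.flatten(names) == ["0.a", "0.b", "1"]
# Flattened names and flattened values line up one-to-one.
feed = dict(zip(tree.flatten(names), tree.flatten(args)))
assert set(feed) == {"0.a", "0.b", "1"}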