[RLlib] No Preprocessors (part 2). (#18468)

Sven Mika 2021-09-23 12:56:45 +02:00 committed by GitHub
parent a2a077b874
commit 61a1274619
25 changed files with 657 additions and 308 deletions

View file

@ -1460,7 +1460,7 @@ py_test(
py_test(
name = "test_preprocessors",
tags = ["team:ml", "models"],
size = "small",
size = "medium",
srcs = ["models/tests/test_preprocessors.py"]
)
@ -2659,6 +2659,24 @@ py_test(
srcs = ["examples/pettingzoo_env.py"],
)
py_test(
name = "examples/preprocessing_disabled_tf",
main = "examples/preprocessing_disabled.py",
tags = ["team:ml", "examples", "examples_P"],
size = "medium",
srcs = ["examples/preprocessing_disabled.py"],
args = ["--stop-iters=2"]
)
py_test(
name = "examples/preprocessing_disabled_torch",
main = "examples/preprocessing_disabled.py",
tags = ["team:ml", "examples", "examples_P"],
size = "medium",
srcs = ["examples/preprocessing_disabled.py"],
args = ["--framework=torch", "--stop-iters=2"]
)
py_test(
name = "examples/remote_envs_with_inference_done_on_main_node_tf",
main = "examples/remote_envs_with_inference_done_on_main_node.py",

View file

@ -30,12 +30,13 @@ from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \
from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
from ray.rllib.utils.framework import try_import_tf, TensorStructType
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.multi_agent import check_multi_agent
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \
TrainerConfigDict
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.resources import Resources
@ -113,11 +114,6 @@ COMMON_CONFIG: TrainerConfigDict = {
"model": MODEL_DEFAULTS,
# Arguments to pass to the policy optimizer. These vary by optimizer.
"optimizer": {},
# Experimental flag, indicating that TFPolicy will handle more than one
# loss/optimizer. Set this to True, if you would like to return more than
# one loss term from your `loss_fn` and an equal number of optimizers
# from your `optimizer_fn`.
"_tf_policy_handles_more_than_one_loss": False,
# === Environment Settings ===
# Number of steps after which the episode is forced to terminate. Defaults
@ -483,6 +479,20 @@ COMMON_CONFIG: TrainerConfigDict = {
# Default value None allows overwriting with nested dicts
"logger_config": None,
# === API deprecations/simplifications/changes ===
# Experimental flag.
# If True, TFPolicy will handle more than one loss/optimizer.
# Set this to True if you would like to return more than
# one loss term from your `loss_fn` and an equal number of optimizers
# from your `optimizer_fn`.
# In the future, the default for this will be True.
"_tf_policy_handles_more_than_one_loss": False,
# Experimental flag.
# If True, no (observation) preprocessor will be created and
# observations will arrive in the model as they are returned by the env.
# In the future, the default for this will be True.
"_disable_preprocessor_api": False,
# === Deprecated keys ===
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
@ -1128,8 +1138,8 @@ class Trainer(Trainable):
tuple: The full output of policy.compute_actions() if
full_fetch=True or we have an RNN-based Policy.
"""
# Preprocess obs and states
stateDefined = state is not None
# Preprocess obs and states.
state_defined = state is not None
policy = self.get_policy(policy_id)
filtered_obs, filtered_state = [], []
for agent_id, ob in observations.items():
@ -1174,7 +1184,7 @@ class Trainer(Trainable):
unbatched_states[agent_id] = [s[idx] for s in states]
# Return only actions or full tuple
if stateDefined or full_fetch:
if state_defined or full_fetch:
return actions, unbatched_states, infos
else:
return actions
@ -1529,8 +1539,8 @@ class Trainer(Trainable):
# Check model config.
# If no preprocessing, propagate this into the model's config as well
# (so the model will know whether its inputs are preprocessed or not).
if config["preprocessor_pref"] is None:
model_config["_no_preprocessor"] = True
if config["_disable_preprocessor_api"] is True:
model_config["_disable_preprocessor_api"] = True
# Prev_a/r settings.
prev_a_r = model_config.get("lstm_use_prev_action_reward",
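A minimal, illustrative usage sketch for the new `_disable_preprocessor_api` flag added to COMMON_CONFIG above (not part of the change itself; it assumes a standard Tune/PPO setup):

import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    config = {
        "env": "CartPole-v0",
        # Experimental: skip preprocessor creation; observations reach the
        # model exactly as the env returned them.
        "_disable_preprocessor_api": True,
        "framework": "tf",
        "num_workers": 0,
    }
    tune.run("PPO", config=config, stop={"training_iteration": 2}, verbose=1)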

View file

@ -3,6 +3,7 @@ from gym.spaces import Space
import logging
import math
import numpy as np
import tree # pip install dm_tree
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
@ -14,6 +15,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
TensorType, ViewRequirementsDict
from ray.util.debug import log_once
@ -47,7 +49,8 @@ class _AgentCollector:
_next_unroll_id = 0 # disambiguates unrolls within a single episode
def __init__(self, view_reqs):
def __init__(self, view_reqs, policy):
self.policy = policy
# Determine the size of the buffer we need for data before the actual
# episode starts. This is used for 0-buffering of e.g. prev-actions,
# or internal state inputs.
@ -57,10 +60,28 @@ class _AgentCollector:
(1
if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0)
for k, vr in view_reqs.items())
# The actual data buffers (lists holding each timestep's data).
self.buffers: Dict[str, List] = {}
# The actual data buffers. Keys are column names, values are lists
# that contain the sub-components (e.g. for complex obs spaces) with
# each sub-component holding a list of per-timestep tensors.
# E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,)))
# buffers["obs"] = [
# [0, 1], # <- 1st sub-component of observation
# [np.array([.2, .3]), np.array([.0, -.2])] # <- 2nd sub-component
# ]
# NOTE: infos and state_out_... are not flattened because they often
# contain custom dict values whose structure may vary from timestep
# to timestep.
self.buffers: Dict[str, List[List[TensorType]]] = {}
# Maps column names to an example data item, which may be deeply
# nested. These are used such that we'll know how to unflatten
# the flattened data inside self.buffers when building the
# SampleBatch.
self.buffer_structs: Dict[str, Any] = {}
# The episode ID for the agent for which we collect data.
self.episode_id = None
# The unroll ID, unique across all rollouts (within a RolloutWorker).
self.unroll_id = None
# The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added.
self.agent_steps = 0
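As an illustrative aside (a hand-written sketch, not code from this change), the per-column flattening described in the buffer comment above can be reproduced directly with dm_tree:

import numpy as np
import tree  # pip install dm_tree

# Two timesteps of a Dict(a=Discrete(2), b=Box((2,))) observation.
obs_t0 = {"a": 0, "b": np.array([0.2, 0.3])}
obs_t1 = {"a": 1, "b": np.array([0.0, -0.2])}

# One list per flattened sub-component, as in buffers["obs"] above.
obs_buffers = [[] for _ in tree.flatten(obs_t0)]
for obs in (obs_t0, obs_t1):
    for i, sub_obs in enumerate(tree.flatten(obs)):
        obs_buffers[i].append(sub_obs)
# obs_buffers == [[0, 1], [array([0.2, 0.3]), array([ 0. , -0.2])]]

# At build time, the stored example struct (buffer_structs) allows
# re-nesting the flattened np-arrays:
np_data = [np.asarray(b) for b in obs_buffers]
nested = tree.unflatten_as(obs_t0, np_data)
# nested == {"a": array([0, 1]), "b": array([[0.2, 0.3], [0., -0.2]])}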
@ -80,6 +101,13 @@ class _AgentCollector:
init_obs (TensorType): The initial observation tensor (after
`env.reset()`).
"""
# Store episode ID + unroll ID, which will be constant throughout this
# AgentCollector's lifecycle.
self.episode_id = episode_id
if self.unroll_id is None:
self.unroll_id = _AgentCollector._next_unroll_id
_AgentCollector._next_unroll_id += 1
if SampleBatch.OBS not in self.buffers:
self._build_buffers(
single_row={
@ -87,12 +115,19 @@ class _AgentCollector:
SampleBatch.AGENT_INDEX: agent_index,
SampleBatch.ENV_ID: env_id,
SampleBatch.T: t,
SampleBatch.EPS_ID: self.episode_id,
SampleBatch.UNROLL_ID: self.unroll_id,
})
self.buffers[SampleBatch.OBS].append(init_obs)
self.episode_id = episode_id
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
self.buffers[SampleBatch.ENV_ID].append(env_id)
self.buffers[SampleBatch.T].append(t)
# Append data to existing buffers.
flattened = tree.flatten(init_obs)
for i, sub_obs in enumerate(flattened):
self.buffers[SampleBatch.OBS][i].append(sub_obs)
self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
self.buffers[SampleBatch.ENV_ID][0].append(env_id)
self.buffers[SampleBatch.T][0].append(t)
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
None:
@ -103,20 +138,40 @@ class _AgentCollector:
row) to be added to buffer. Must contain keys:
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
"""
if self.unroll_id is None:
self.unroll_id = _AgentCollector._next_unroll_id
_AgentCollector._next_unroll_id += 1
# Next obs -> obs.
assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS]
# Make sure EPS_ID stays the same for this agent. Usually, it should
# not be part of `values` anyways.
# Make sure EPS_ID/UNROLL_ID stay the same for this agent.
if SampleBatch.EPS_ID in values:
assert values[SampleBatch.EPS_ID] == self.episode_id
del values[SampleBatch.EPS_ID]
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
if SampleBatch.UNROLL_ID in values:
assert values[SampleBatch.UNROLL_ID] == self.unroll_id
del values[SampleBatch.UNROLL_ID]
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
for k, v in values.items():
if k not in self.buffers:
self._build_buffers(single_row=values)
self.buffers[k].append(v)
# Do not flatten infos, state_out_, or actions.
# Infos/state-outs may be structs that change from timestep to
# timestep. Actions, on the other hand, are already flattened
# in the sampler.
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS
] or k.startswith("state_out_"):
self.buffers[k][0].append(v)
# Flatten all other columns.
else:
flattened = tree.flatten(v)
for i, sub_list in enumerate(self.buffers[k]):
sub_list.append(flattened[i])
self.agent_steps += 1
def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
@ -157,7 +212,9 @@ class _AgentCollector:
# Keep an np-array cache so we don't have to regenerate the
# np-array for different view_cols that use the same data_col.
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
np_data[data_col] = [
to_float_np_array(d) for d in self.buffers[data_col]
]
# Range of indices on time-axis, e.g. "-50:-1". Together with
# the `batch_repeat_value`, this determines the data produced.
@ -171,42 +228,50 @@ class _AgentCollector:
# every n timesteps.
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
/ view_req.batch_repeat_value))
data = np.asarray([
np_data[data_col][self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_from +
obs_shift:self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_to + 1 + obs_shift]
for i in range(count)
])
math.ceil(
(len(np_data[data_col][0]) - self.shift_before) /
view_req.batch_repeat_value))
data = [
np.asarray([
d[self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_from +
obs_shift:self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_to + 1 + obs_shift]
for i in range(count)
]) for d in np_data[data_col]
]
# Batch repeat value = 1: Repeat the shift_from/to range at
# each timestep.
else:
d = np_data[data_col]
d0 = np_data[data_col][0]
shift_win = view_req.shift_to - view_req.shift_from + 1
data_size = d.itemsize * int(np.product(d.shape[1:]))
data_size = d0.itemsize * int(np.product(d0.shape[1:]))
strides = [
d.itemsize * int(np.product(d.shape[i + 1:]))
for i in range(1, len(d.shape))
d0.itemsize * int(np.product(d0.shape[i + 1:]))
for i in range(1, len(d0.shape))
]
start = self.shift_before - shift_win + 1 + obs_shift + \
view_req.shift_to
data = np.lib.stride_tricks.as_strided(
d[start:start + self.agent_steps],
[self.agent_steps, shift_win
] + [d.shape[i] for i in range(1, len(d.shape))],
[data_size, data_size] + strides)
data = [
np.lib.stride_tricks.as_strided(
d[start:start + self.agent_steps],
[self.agent_steps, shift_win
] + [d.shape[i] for i in range(1, len(d.shape))],
[data_size, data_size] + strides)
for d in np_data[data_col]
]
# Set of (probably non-consecutive) indices.
# Example:
# shift=[-3, 0]
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
elif isinstance(view_req.shift, np.ndarray):
data = np_data[data_col][self.shift_before + obs_shift +
view_req.shift]
data = [
d[self.shift_before + obs_shift + view_req.shift]
for d in np_data[data_col]
]
# Single shift int value. Use the trajectory as-is, and if
# `shift` != 0: shifted by that value.
else:
@ -215,58 +280,77 @@ class _AgentCollector:
# Batch repeat (only provide a value every n timesteps).
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
/ view_req.batch_repeat_value))
data = np.asarray([
np_data[data_col][self.shift_before + (
i * view_req.batch_repeat_value) + shift]
for i in range(count)
])
math.ceil(
(len(np_data[data_col][0]) - self.shift_before) /
view_req.batch_repeat_value))
data = [
np.asarray([
d[self.shift_before +
(i * view_req.batch_repeat_value) + shift]
for i in range(count)
]) for d in np_data[data_col]
]
# Shift is exactly 0: Use trajectory as is.
elif shift == 0:
data = np_data[data_col][self.shift_before:]
data = [d[self.shift_before:] for d in np_data[data_col]]
# Shift is positive: We still need to 0-pad at the end.
elif shift > 0:
data = to_float_np_array(
self.buffers[data_col][self.shift_before + shift:] + [
np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype)
for _ in range(shift)
])
data = [
to_float_np_array(
np.concatenate([
d[self.shift_before + shift:], [
np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype)
for _ in range(shift)
]
])) for d in np_data[data_col]
]
# Shift is negative: Shift into the already existing and
# 0-padded "before" area of our buffers.
else:
data = np_data[data_col][self.shift_before + shift:shift]
data = [
d[self.shift_before + shift:shift]
for d in np_data[data_col]
]
if len(data) > 0:
batch_data[view_col] = data
if data_col not in self.buffer_structs:
batch_data[view_col] = data[0]
else:
batch_data[view_col] = tree.unflatten_as(
self.buffer_structs[data_col], data)
# Due to possible batch-repeats > 1, columns in the resulting batch
# may not all have the same batch size.
batch = SampleBatch(batch_data)
# Add EPS_ID and UNROLL_ID to batch.
batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count)
if SampleBatch.UNROLL_ID not in batch:
# TODO: (sven) Once we have the additional
# model.preprocess_train_batch in place (attention net PR), we
# should not even need UNROLL_ID anymore:
# Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
batch[SampleBatch.UNROLL_ID] = np.repeat(
_AgentCollector._next_unroll_id, batch.count)
_AgentCollector._next_unroll_id += 1
# Adjust the seq-lens array depending on the incoming agent sequences.
if self.policy.is_recurrent():
seq_lens = []
max_seq_len = self.policy.config["model"]["max_seq_len"]
count = batch.count
while count > 0:
seq_lens.append(min(count, max_seq_len))
count -= max_seq_len
batch["seq_lens"] = np.array(seq_lens)
batch.max_seq_len = max_seq_len
# This trajectory is continuing -> Copy data at the end (in the size of
# self.shift_before) to the beginning of buffers and erase everything
# else.
if not self.buffers[SampleBatch.DONES][-1]:
if not self.buffers[SampleBatch.DONES][0][-1]:
# Copy data to beginning of buffer and cut lists.
if self.shift_before > 0:
for k, data in self.buffers.items():
self.buffers[k] = data[-self.shift_before:]
# Loop through each sub-component list of this column and
# keep only the last `shift_before` items.
for i in range(len(data)):
self.buffers[k][i] = data[i][-self.shift_before:]
self.agent_steps = 0
# Reset our unroll_id.
self.unroll_id = None
return batch
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
@ -279,12 +363,25 @@ class _AgentCollector:
for col, data in single_row.items():
if col in self.buffers:
continue
shift = self.shift_before - (1 if col in [
SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.ENV_ID, SampleBatch.T
SampleBatch.ENV_ID, SampleBatch.T, SampleBatch.UNROLL_ID
] else 0)
# Python primitive, tensor, or dict (e.g. INFOs).
self.buffers[col] = [data for _ in range(shift)]
# Store all data as flattened lists, except INFOS, ACTIONS, and
# state-out columns. These are stored as monolithic items (infos
# is a dict that should not be split further; the same holds for
# state-out items, which could be custom dicts as well; actions
# have already been flattened by the sampler).
if col in [SampleBatch.INFOS, SampleBatch.ACTIONS
] or col.startswith("state_out_"):
self.buffers[col] = [[data for _ in range(shift)]]
else:
self.buffers[col] = [[v for _ in range(shift)]
for v in tree.flatten(data)]
# Store an example data struct so we know how to unflatten
# each data col.
self.buffer_structs[col] = data
class _PolicyCollector:
@ -302,15 +399,13 @@ class _PolicyCollector:
policy (Policy): The policy object.
"""
self.buffers: Dict[str, List] = collections.defaultdict(list)
self.batches = []
self.policy = policy
# The total timestep count for all agents that use this policy.
# NOTE: This is not an env-step count (across n agents). AgentA and
# agentB, both using this policy, acting in the same episode and both
# doing n steps would increase the count by 2*n.
self.agent_steps = 0
# Seq-lens list of already added agent batches.
self.seq_lens = [] if policy.is_recurrent() else None
def add_postprocessed_batch_for_training(
self, batch: SampleBatch,
@ -325,22 +420,13 @@ class _PolicyCollector:
view-column needs to be copied at all (not needed for
training).
"""
for view_col, data in batch.items():
# 1) If col is not in view_requirements, we must have a direct
# child of the base Policy that doesn't do auto-view req creation.
# 2) Col is in view-reqs and needed for training.
view_req = view_requirements.get(view_col)
if view_req is None or view_req.used_for_training:
self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count.
self.agent_steps += batch.count
# Adjust the seq-lens array depending on the incoming agent sequences.
if self.seq_lens is not None:
max_seq_len = self.policy.config["model"]["max_seq_len"]
count = batch.count
while count > 0:
self.seq_lens.append(min(count, max_seq_len))
count -= max_seq_len
# And remove columns not needed for training.
for view_col, view_req in view_requirements.items():
if view_col in batch and not view_req.used_for_training:
del batch[view_col]
self.batches.append(batch)
def build(self):
"""Builds a SampleBatch for this policy from the collected data.
@ -352,13 +438,11 @@ class _PolicyCollector:
this policy.
"""
# Create batch from our buffers.
batch = SampleBatch(self.buffers, seq_lens=self.seq_lens)
# Clear buffers for future samples.
self.buffers.clear()
batch = SampleBatch.concat_samples(self.batches)
# Clear batches for future samples.
self.batches = []
# Reset agent steps to 0 and seq-lens to empty list.
self.agent_steps = 0
if self.seq_lens is not None:
self.seq_lens = []
return batch
@ -479,7 +563,7 @@ class SimpleListCollector(SampleCollector):
# Add initial obs to Trajectory.
assert agent_key not in self.agent_collectors
# TODO: determine exact shift-before based on the view-req shifts.
self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
self.agent_collectors[agent_key] = _AgentCollector(view_reqs, policy)
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_index=episode._agent_index(agent_id),
@ -537,7 +621,13 @@ class SimpleListCollector(SampleCollector):
Dict[str, TensorType]:
policy = self.policy_map[policy_id]
keys = self.forward_pass_agent_keys[policy_id]
buffers = {k: self.agent_collectors[k].buffers for k in keys}
buffers = {}
for k in keys:
collector = self.agent_collectors[k]
buffers[k] = collector.buffers
# Use one agent's buffer_structs (they should all be the same).
buffer_structs = self.agent_collectors[keys[0]].buffer_structs
input_dict = {}
for view_col, view_req in policy.view_requirements.items():
@ -558,33 +648,57 @@ class SimpleListCollector(SampleCollector):
# Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
else:
time_indices = view_req.shift + delta
data_list = []
# Loop through agents and add-up their data (batch).
# Loop through agents and add up their data (batch).
data = None
for k in keys:
if data_col == SampleBatch.EPS_ID:
data_list.append(self.agent_collectors[k].episode_id)
else:
if data_col not in buffers[k]:
if view_req.data_col is not None:
space = policy.view_requirements[
view_req.data_col].space
else:
space = view_req.space
fill_value = np.zeros_like(space.sample()) \
if isinstance(space, Space) else space
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
if isinstance(time_indices, tuple):
if time_indices[1] == -1:
data_list.append(
buffers[k][data_col][time_indices[0]:])
else:
data_list.append(buffers[k][data_col][time_indices[
0]:time_indices[1] + 1])
# Buffer for the data does not exist yet: Create dummy
# (zero) data.
if data_col not in buffers[k]:
if view_req.data_col is not None:
space = policy.view_requirements[
view_req.data_col].space
else:
data_list.append(buffers[k][data_col][time_indices])
input_dict[view_col] = np.array(data_list)
space = view_req.space
if isinstance(space, Space):
fill_value = get_dummy_batch_for_space(
space,
batch_size=0,
)
else:
fill_value = space
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
if data is None:
data = [[] for _ in range(len(buffers[keys[0]][data_col]))]
# `shift_from` and `shift_to` are defined: User wants a
# view with some time-range.
if isinstance(time_indices, tuple):
# `shift_to` == -1: Until the end (including(!) the
# last item).
if time_indices[1] == -1:
for d, b in zip(data, buffers[k][data_col]):
d.append(b[time_indices[0]:])
# `shift_to` != -1: "Normal" range.
else:
for d, b in zip(data, buffers[k][data_col]):
d.append(b[time_indices[0]:time_indices[1] + 1])
# Single index.
else:
for d, b in zip(data, buffers[k][data_col]):
d.append(b[time_indices])
np_data = [np.array(d) for d in data]
if data_col in buffer_structs:
input_dict[view_col] = tree.unflatten_as(
buffer_structs[data_col], np_data)
else:
input_dict[view_col] = np_data[0]
self._reset_inference_calls(policy_id)
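A small worked example (illustrative only) of the seq-lens chunking that `_AgentCollector.build()` above now performs for recurrent policies: a trajectory of 23 agent steps with max_seq_len=20 is split into sequences of lengths 20 and 3.

max_seq_len = 20
count = 23  # batch.count
seq_lens = []
while count > 0:
    seq_lens.append(min(count, max_seq_len))
    count -= max_seq_len
assert seq_lens == [20, 3]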

View file

@ -20,7 +20,7 @@ from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
@ -44,7 +44,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
ModelConfigDict, ModelGradients, ModelWeights, \
MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
SampleBatchType, TrainerConfigDict
SampleBatchType
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.iter import ParallelIteratorWorker
@ -168,7 +168,8 @@ class RolloutWorker(ParallelIteratorWorker):
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType, EnvContext],
None]] = None,
policy_spec: Union[type, Dict[PolicyID, PolicySpec]] = None,
policy_spec: Optional[Union[type, Dict[PolicyID,
PolicySpec]]] = None,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
@ -176,24 +177,24 @@ class RolloutWorker(ParallelIteratorWorker):
rollout_fragment_length: int = 100,
count_steps_by: str = "env_steps",
batch_mode: str = "truncate_episodes",
episode_horizon: int = None,
preprocessor_pref: Optional[str] = "deepmind",
episode_horizon: Optional[int] = None,
preprocessor_pref: str = "deepmind",
sample_async: bool = False,
compress_observations: bool = False,
num_envs: int = 1,
observation_fn: "ObservationFunction" = None,
observation_fn: Optional["ObservationFunction"] = None,
observation_filter: str = "NoFilter",
clip_rewards: Optional[Union[bool, float]] = None,
normalize_actions: bool = True,
clip_actions: bool = False,
env_config: EnvConfigDict = None,
model_config: ModelConfigDict = None,
policy_config: TrainerConfigDict = None,
env_config: Optional[EnvConfigDict] = None,
model_config: Optional[ModelConfigDict] = None,
policy_config: Optional[PartialTrainerConfigDict] = None,
worker_index: int = 0,
num_workers: int = 0,
record_env: Union[bool, str] = False,
log_dir: str = None,
log_level: str = None,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
callbacks: Type["DefaultCallbacks"] = None,
input_creator: Callable[[
IOContext
@ -206,7 +207,7 @@ class RolloutWorker(ParallelIteratorWorker):
soft_horizon: bool = False,
no_done_at_end: bool = False,
seed: int = None,
extra_python_environs: dict = None,
extra_python_environs: Optional[dict] = None,
fake_sampler: bool = False,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
@ -258,10 +259,10 @@ class RolloutWorker(ParallelIteratorWorker):
that when `num_envs > 1`, episode steps will be buffered
until the episode completes, and hence batches may contain
significant amounts of off-policy data.
episode_horizon (int): Whether to stop episodes at this horizon.
preprocessor_pref (Optional[str]): Whether to use no preprocessor
(None), RLlib preprocessors ("rllib") or deepmind ("deepmind"),
when applicable.
episode_horizon: Horizon at which to stop episodes (even if the
environment itself has not returned a "done" signal).
preprocessor_pref (str): Whether to use RLlib preprocessors
("rllib") or deepmind ("deepmind"), when applicable.
sample_async (bool): Whether to compute samples asynchronously in
the background, which improves throughput but can cause samples
to be slightly off-policy.
@ -284,9 +285,9 @@ class RolloutWorker(ParallelIteratorWorker):
env_config (EnvConfigDict): Config to pass to the env creator.
model_config (ModelConfigDict): Config to use when creating the
policy model.
policy_config (TrainerConfigDict): Config to pass to the policy.
In the multi-agent case, this config will be merged with the
per-policy configs specified by `policy_spec`.
policy_config: Config to pass to the
policy. In the multi-agent case, this config will be merged
with the per-policy configs specified by `policy_spec`.
worker_index (int): For remote workers, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
@ -378,7 +379,7 @@ class RolloutWorker(ParallelIteratorWorker):
ParallelIteratorWorker.__init__(self, gen_rollouts, False)
policy_config: TrainerConfigDict = policy_config or {}
policy_config = policy_config or {}
if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
# This eager check is necessary for certain all-framework tests
# that use tf's eager_mode() context generator.
@ -400,7 +401,7 @@ class RolloutWorker(ParallelIteratorWorker):
num_workers=num_workers,
)
self.env_context = env_context
self.policy_config: TrainerConfigDict = policy_config
self.policy_config: PartialTrainerConfigDict = policy_config
if callbacks:
self.callbacks: "DefaultCallbacks" = callbacks()
else:
@ -424,10 +425,10 @@ class RolloutWorker(ParallelIteratorWorker):
self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = False \
if preprocessor_pref is None else True
if policy_config.get("_disable_preprocessor_api") else True
self.observation_filter = observation_filter
self.last_batch: SampleBatchType = None
self.global_vars: dict = None
self.last_batch: Optional[SampleBatchType] = None
self.global_vars: Optional[dict] = None
self.fake_sampler: bool = fake_sampler
# Update the global seed for numpy/random/tf-eager/torch if we are not
@ -1076,10 +1077,9 @@ class RolloutWorker(ParallelIteratorWorker):
space of the policy to add.
action_space (Optional[gym.spaces.Space]): The action space
of the policy to add.
config (Optional[PartialTrainerConfigDict]): The config
overrides for the policy to add.
policy_config (Optional[TrainerConfigDict]): The base config of the
Trainer object owning this RolloutWorker.
config: The config overrides for the policy to add.
policy_config: The base config of the Trainer object owning this
RolloutWorker.
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
PolicyID]]): An optional (updated) policy mapping function to
use from here on. Note that already ongoing episodes will not
@ -1340,7 +1340,7 @@ class RolloutWorker(ParallelIteratorWorker):
def _build_policy_map(
self,
policy_dict: MultiAgentPolicyConfigDict,
policy_config: TrainerConfigDict,
policy_config: PartialTrainerConfigDict,
session_creator: Optional[Callable[[], "tf1.Session"]] = None,
seed: Optional[int] = None,
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
@ -1371,13 +1371,7 @@ class RolloutWorker(ParallelIteratorWorker):
if preprocessor is not None:
obs_space = preprocessor.observation_space
else:
self.preprocessors[name] = NoPreprocessor(obs_space)
if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
raise ValueError(
"Found raw Tuple|Dict space as input to policy. "
"Please preprocess these observations with a "
"Tuple|DictFlatteningPreprocessor.")
self.preprocessors[name] = None
self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
conf, merged_conf)

View file

@ -275,14 +275,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: int = None,
horizon: Optional[int] = None,
multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True,
clip_actions: bool = False,
blackhole_outputs: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
# Obsolete.
@ -308,7 +308,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
Refers to the unit of `rollout_fragment_length`.
callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout.
horizon (Optional[int]): Hard-reset the Env
horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
@ -452,7 +452,7 @@ def _env_runner(
worker: "RolloutWorker",
base_env: BaseEnv,
extra_batch_callback: Callable[[SampleBatchType], None],
horizon: int,
horizon: Optional[int],
normalize_actions: bool,
clip_actions: bool,
multiple_episodes_in_batch: bool,
@ -470,7 +470,7 @@ def _env_runner(
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
horizon (int): Horizon of the episode.
horizon: Horizon of the episode.
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.

View file

@ -117,21 +117,24 @@ class TestRolloutWorker(unittest.TestCase):
ev.stop()
def test_batch_ids(self):
fragment_len = 100
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_spec=MockPolicy,
rollout_fragment_length=1)
rollout_fragment_length=fragment_len)
batch1 = ev.sample()
batch2 = ev.sample()
self.assertEqual(len(set(batch1["unroll_id"])), 1)
self.assertEqual(len(set(batch2["unroll_id"])), 1)
self.assertEqual(
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
unroll_ids_1 = set(batch1["unroll_id"])
unroll_ids_2 = set(batch2["unroll_id"])
# Assert no overlap of unroll IDs between sample() calls.
self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1))
# CartPole episodes should be short initially: Expect more than one
# unroll ID in each batch.
self.assertTrue(len(unroll_ids_1) > 1)
self.assertTrue(len(unroll_ids_2) > 1)
ev.stop()
def test_global_vars_update(self):
# Allow for Unittest run.
ray.init(num_cpus=5, ignore_reinit_error=True)
for fw in framework_iterator(frameworks=("tf2", "tf")):
agent = A2CTrainer(
env="CartPole-v0",

View file

@ -185,31 +185,49 @@ class TestTrajectoryViewAPI(unittest.TestCase):
policy_mapping_fn=None,
num_envs=1,
)
# Add the next action to the view reqs of the policy.
# Add the next action (a') and 2nd next action (a'') to the view
# requirements of the policy.
# This should be visible then in postprocessing and train batches.
# Switch off for action computations (can't be there as we don't know
# the next action already at action computation time).
# the next actions already at action computation time).
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"next_actions"] = ViewRequirement(
SampleBatch.ACTIONS,
shift=1,
space=action_space,
used_for_compute_actions=False)
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"2nd_next_actions"] = ViewRequirement(
SampleBatch.ACTIONS,
shift=2,
space=action_space,
used_for_compute_actions=False)
# Make sure, we have DONEs as well.
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"dones"] = ViewRequirement()
batch = rollout_worker_w_api.sample()
self.assertTrue("next_actions" in batch)
self.assertTrue("2nd_next_actions" in batch)
expected_a_ = None # expected next action
expected_a__ = None # expected 2nd next action
for i in range(len(batch["actions"])):
a, d, a_ = batch["actions"][i], batch["dones"][i], \
batch["next_actions"][i]
if not d and expected_a_ is not None:
check(a, expected_a_)
elif d:
a, d, a_, a__ = \
batch["actions"][i], batch["dones"][i], \
batch["next_actions"][i], batch["2nd_next_actions"][i]
# Episode done: next action and 2nd next action should be 0.
if d:
check(a_, 0)
check(a__, 0)
expected_a_ = None
expected_a__ = None
continue
# Episode is not done and we have an expected next-a.
if expected_a_ is not None:
check(a, expected_a_)
if expected_a__ is not None:
check(a_, expected_a__)
expected_a__ = a__
expected_a_ = a_
def test_traj_view_lstm_functionality(self):

View file

@ -57,7 +57,7 @@ class MyCallbacks(DefaultCallbacks):
env_index: int, **kwargs):
# Make sure this episode is really done.
assert episode.batch_builder.policy_collectors[
"default_policy"].buffers["dones"][-1], \
"default_policy"].batches[-1]["dones"][-1], \
"ERROR: `on_episode_end()` should only be called " \
"after episode is done!"
pole_angle = np.mean(episode.user_data["pole_angles"])

View file

@ -49,6 +49,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
},
**self.model.view_requirements)

View file

@ -0,0 +1,107 @@
"""
Example of using _disable_preprocessor_api=True to disable all preprocessing.
This example shows:
- How a complex observation space from the env is handled directly by the
model.
- How complex observations are flattened into lists of tensors and stored
as such by the SampleCollectors.
- That this has the advantage of preprocessing happening in a batched
fashion (inside the model).
"""
import argparse
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import os
import ray
from ray import tune
def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()
# general args
parser.add_argument(
"--run", default="PPO", help="The RLlib-registered algorithm to use.")
parser.add_argument("--num-cpus", type=int, default=3)
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument(
"--stop-iters",
type=int,
default=200,
help="Number of iterations to train.")
parser.add_argument(
"--stop-timesteps",
type=int,
default=500000,
help="Number of timesteps to train.")
parser.add_argument(
"--stop-reward",
type=float,
default=80.0,
help="Reward at which we stop training.")
parser.add_argument(
"--as-test",
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
"--no-tune",
action="store_true",
help="Run without Tune using a manual train loop instead. Here,"
"there is no TensorBoard support.")
parser.add_argument(
"--local-mode",
action="store_true",
help="Init Ray in local mode for easier debugging.")
args = parser.parse_args()
print(f"Running with following CLI args: {args}")
return args
if __name__ == "__main__":
args = get_cli_args()
ray.init(local_mode=args.local_mode)
config = {
"env": "ray.rllib.examples.env.random_env.RandomEnv",
"env_config": {
"config": {
"observation_space": Dict({
"a": Discrete(2),
"b": Dict({
"ba": Discrete(3),
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
}),
"c": Tuple((MultiDiscrete([2, 3]), Discrete(2))),
"d": Box(-1.0, 1.0, (2, ), dtype=np.int32),
}),
},
},
# Set this to True to enforce that no preprocessors are used.
# Complex observations now arrive directly in the model as
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
"_disable_preprocessor_api": True,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
"framework": args.framework,
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop, verbose=2)
ray.shutdown()
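For reference, a hand-written (illustrative, not produced by the script) size-2 observation batch for the Dict space above, showing the nested structure the model receives when `_disable_preprocessor_api=True`:

import numpy as np

example_obs_batch = {
    "a": np.array([0, 1]),                                # Discrete(2)
    "b": {
        "ba": np.array([2, 0]),                           # Discrete(3)
        "bb": np.zeros((2, 2, 3), dtype=np.float32),      # Box(-1, 1, (2, 3))
    },
    "c": (np.array([[1, 2], [0, 1]]), np.array([1, 0])),  # Tuple(MultiDiscrete, Discrete)
    "d": np.zeros((2, 2), dtype=np.int32),                # Box(-1, 1, (2,))
}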

View file

@ -502,7 +502,8 @@ class LocalReplayBuffer(ParallelIteratorWorker):
# If SampleBatch has prio-replay weights, average
# over these to use as a weight for the entire
# sequence.
if "weights" in time_slice:
if "weights" in time_slice and \
len(time_slice["weights"]):
weight = np.mean(time_slice["weights"])
else:
weight = None

View file

@ -46,9 +46,9 @@ MODEL_DEFAULTS: ModelConfigDict = {
"_use_default_native_models": False,
# Experimental flag.
# If True, user specified no preprocessor to be created
# (via config.preprocessor_pref=None). If True, observations will arrive
# in model as they are returned by the env.
"_no_preprocessing": False,
# (via config._disable_preprocessor_api=True). If True, observations
# will arrive in the model as they are returned by the env.
"_disable_preprocessor_api": False,
# === Built-in options ===
# FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py

View file

@ -217,12 +217,12 @@ class ModelV2:
else:
restored = input_dict.copy()
# No Preprocessor used: `config.preprocessor_pref`=None.
# No Preprocessor used: `config._disable_preprocessor_api`=True.
# TODO: This is unnecessary when no preprocessor is used.
# Obs are not flat then anymore. However, we'll keep this
# here for backward-compatibility until Preprocessors have
# been fully deprecated.
if self.model_config.get("_no_preprocessing"):
if self.model_config.get("_disable_preprocessor_api"):
restored["obs_flat"] = input_dict["obs"]
# Input to this Model went through a Preprocessor.
# Generate extra keys: "obs_flat" (vs "obs", which will hold the

View file

@ -1,7 +1,7 @@
from typing import List
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.framework import TensorType, TensorStructType
from ray.rllib.utils.typing import TensorType, TensorStructType
@PublicAPI

View file

@ -3,14 +3,57 @@ from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \
get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \
OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator
class TestPreprocessors(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_preprocessing_disabled(self):
config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "ray.rllib.examples.env.random_env.RandomEnv"
config["env_config"] = {
"config": {
"observation_space": Dict({
"a": Discrete(5),
"b": Dict({
"ba": Discrete(4),
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
}),
"c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
"d": Box(-1.0, 1.0, (1, ), dtype=np.int32),
}),
},
}
# Set this to True to enforce that no preprocessors are used.
# Complex observations now arrive directly in the model as
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
config["_disable_preprocessor_api"] = True
num_iterations = 1
# Only supported for tf so far.
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config=config)
for i in range(num_iterations):
print(trainer.train())
check_compute_single_action(trainer)
trainer.stop()
def test_gym_preprocessors(self):
p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
self.assertEqual(type(p1), NoPreprocessor)

View file

@ -133,7 +133,7 @@ class ComplexInputNetwork(TFModelV2):
cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
outs.append(cnn_out)
elif i in self.one_hot:
if component.dtype in [tf.int32, tf.int64, tf.uint8]:
if "int" in component.dtype.name:
outs.append(
one_hot(component, self.flattened_input_space[i]))
else:

View file

@ -546,6 +546,7 @@ class DynamicTFPolicy(TFPolicy):
# Skip action dist inputs placeholder (do later).
elif view_col == SampleBatch.ACTION_DIST_INPUTS:
continue
# This is a tower: input placeholders already exist.
elif view_col in existing_inputs:
input_dict[view_col] = existing_inputs[view_col]
# All others.
@ -554,10 +555,15 @@ class DynamicTFPolicy(TFPolicy):
if view_req.used_for_training:
# Create a +time-axis placeholder if the shift is not an
# int (range or list of ints).
flatten = view_col not in [
SampleBatch.OBS, SampleBatch.NEXT_OBS] or \
not self.config["_disable_preprocessor_api"]
input_dict[view_col] = get_placeholder(
space=view_req.space,
name=view_col,
time_axis=time_axis)
time_axis=time_axis,
flatten=flatten,
)
dummy_batch = self._get_dummy_batch_from_view_requirements(
batch_size=32)

View file

@ -5,6 +5,7 @@ It supports both traced and non-traced eager execution modes."""
import functools
import logging
import threading
import tree # pip install dm_tree
from typing import Dict, List, Optional, Tuple
from ray.util.debug import log_once
@ -425,10 +426,12 @@ def build_eager_tf_policy(
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
input_dict = {
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
"is_training": tf.constant(False),
}
input_dict = SampleBatch(
{
SampleBatch.CUR_OBS: tree.map_structure(
lambda s: tf.convert_to_tensor(s), obs_batch),
},
_is_training=tf.constant(False))
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch)
@ -478,7 +481,7 @@ def build_eager_tf_policy(
self._is_training = False
self._state_in = state_batches or []
# Calculate RNN sequence lengths.
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
else None
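An illustrative aside on the batch-size lookup above: with a possibly nested observation batch, the size is taken from the first flattened leaf (values below are made up):

import numpy as np
import tree  # pip install dm_tree

obs_batch = {"cam": np.zeros((4, 84, 84, 3)), "state": np.zeros((4, 7))}
batch_size = tree.flatten(obs_batch)[0].shape[0]
assert batch_size == 4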

View file

@ -4,6 +4,7 @@ import gym
from gym.spaces import Box
import logging
import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional, TYPE_CHECKING
from ray.rllib.models.catalog import ModelCatalog
@ -17,7 +18,7 @@ from ray.rllib.utils.spaces.space_utils import clip_action, \
get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \
unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict, Tuple, Union
TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -132,10 +133,12 @@ class Policy(metaclass=ABCMeta):
@DeveloperAPI
def compute_actions(
self,
obs_batch: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorStructType], TensorStructType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None,
prev_action_batch: Union[List[TensorStructType],
TensorStructType] = None,
prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None,
@ -145,14 +148,10 @@ class Policy(metaclass=ABCMeta):
"""Computes actions for the current policy.
Args:
obs_batch (Union[List[TensorType], TensorType]): Batch of
observations.
state_batches (Optional[List[TensorType]]): List of RNN state input
batches, if any.
prev_action_batch (Union[List[TensorType], TensorType]): Batch of
previous action values.
prev_reward_batch (Union[List[TensorType], TensorType]): Batch of
previous rewards.
obs_batch: Batch of observations.
state_batches: List of RNN state input batches, if any.
prev_action_batch: Batch of previous action values.
prev_reward_batch: Batch of previous rewards.
info_batch (Optional[Dict[str, list]]): Batch of info objects.
episodes (Optional[List[MultiAgentEpisode]] ): List of
MultiAgentEpisode, one for each obs in obs_batch. This provides
@ -181,10 +180,10 @@ class Policy(metaclass=ABCMeta):
@DeveloperAPI
def compute_single_action(
self,
obs: TensorType,
obs: TensorStructType,
state: Optional[List[TensorType]] = None,
prev_action: Optional[TensorType] = None,
prev_reward: Optional[TensorType] = None,
prev_action: Optional[TensorStructType] = None,
prev_reward: Optional[TensorStructType] = None,
info: dict = None,
episode: Optional["MultiAgentEpisode"] = None,
clip_actions: bool = None,
@ -192,37 +191,34 @@ class Policy(metaclass=ABCMeta):
timestep: Optional[int] = None,
unsquash_actions: bool = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
"""Unbatched version of compute_actions.
Args:
obs (TensorType): Single observation.
state (Optional[List[TensorType]]): List of RNN state inputs, if
any.
prev_action (Optional[TensorType]): Previous action value, if any.
prev_reward (Optional[TensorType]): Previous reward, if any.
obs: Single observation.
state: List of RNN state inputs, if any.
prev_action: Previous action value, if any.
prev_reward: Previous reward, if any.
info (dict): Info object, if any.
episode (Optional[MultiAgentEpisode]): this provides access to all
episode: this provides access to all
of the internal episode state, which may be useful for
model-based or multi-agent algorithms.
unsquash_actions (bool): Should actions be unsquashed according to
unsquash_actions: Should actions be unsquashed according to
the Policy's action space?
clip_actions (bool): Should actions be clipped according to the
clip_actions: Should actions be clipped according to the
Policy's action space?
explore (Optional[bool]): Whether to pick an exploitation or
explore: Whether to pick an exploitation or
exploration action
(default: None -> use self.config["explore"]).
timestep (Optional[int]): The current (sampling) time step.
timestep: The current (sampling) time step.
Keyword Args:
kwargs: Forward compatibility.
Returns:
Tuple:
- actions (TensorType): Single action.
- state_outs (List[TensorType]): List of RNN state outputs,
if any.
- info (dict): Dictionary of extra features, if any.
- actions: Single action.
- state_outs: List of RNN state outputs, if any.
- info: Dictionary of extra features, if any.
"""
# If policy works in normalized space, we should unsquash the action.
# Use value of config.normalize_actions, if None.
@ -253,7 +249,7 @@ class Policy(metaclass=ABCMeta):
]
out = self.compute_actions(
[obs],
tree.map_structure(lambda s: np.array([s]), obs),
state_batch,
prev_action_batch=prev_action_batch,
prev_reward_batch=prev_reward_batch,
@ -805,7 +801,8 @@ class Policy(metaclass=ABCMeta):
for key in all_accessed_keys:
if key not in self.view_requirements and \
key != SampleBatch.SEQ_LENS:
self.view_requirements[key] = ViewRequirement()
self.view_requirements[key] = ViewRequirement(
used_for_compute_actions=False)
if self._loss:
# Tag those only needed for post-processing (with some
# exceptions).
@ -852,7 +849,7 @@ class Policy(metaclass=ABCMeta):
"""
ret = {}
for view_col, view_req in self.view_requirements.items():
if self.config["preprocessor_pref"] is not None and \
if not self.config["_disable_preprocessor_api"] and \
isinstance(view_req.space,
(gym.spaces.Dict, gym.spaces.Tuple)):
_, shape = ModelCatalog.get_action_shape(
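Illustrative sketch (not part of the change) of the batching step in `compute_single_action` above: a single, possibly nested observation is turned into a batch of size one, leaf by leaf:

import numpy as np
import tree  # pip install dm_tree

single_obs = {"pos": np.array([0.1, -0.2]), "goal": 3}
batched = tree.map_structure(lambda s: np.array([s]), single_obs)
# batched["pos"].shape == (1, 2); batched["goal"].shape == (1,)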

View file

@ -364,6 +364,9 @@ class SampleBatch(dict):
for i, p in enumerate(path):
if i == len(path) - 1:
curr[p] = value[permutation]
# Translate into list (tuples are immutable).
if isinstance(curr[p], tuple):
curr[p] = list(curr[p])
curr = curr[p]
tree.map_structure_with_path(_permutate_in_place, self)

View file

@ -4,6 +4,7 @@ import logging
import math
import numpy as np
import os
import tree # pip install dm_tree
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
import ray
@ -435,15 +436,16 @@ class TFPolicy(Policy):
fetched = builder.get(to_fetch)
# Update our global timestep by the batch size.
self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
else obs_batch.shape[0]
self.global_timestep += \
len(obs_batch) if isinstance(obs_batch, list) \
else tree.flatten(obs_batch)[0].shape[0]
return fetched
@override(Policy)
def compute_actions_from_input_dict(
self,
input_dict: Dict[str, TensorType],
input_dict: Union[SampleBatch, Dict[str, TensorType]],
explore: bool = None,
timestep: Optional[int] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
@ -464,6 +466,7 @@ class TFPolicy(Policy):
# Update our global timestep by the batch size.
self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
else len(input_dict) if isinstance(input_dict, SampleBatch) \
else obs_batch.shape[0]
return fetched
@ -965,7 +968,11 @@ class TFPolicy(Policy):
if hasattr(self, "_input_dict"):
for key, value in input_dict.items():
if key in self._input_dict:
builder.add_feed_dict({self._input_dict[key]: value})
# Handle complex/nested spaces as well.
tree.map_structure(
lambda k, v: builder.add_feed_dict({k: v}),
self._input_dict[key], value,
)
# For policies that inherit directly from TFPolicy.
else:
builder.add_feed_dict({
@ -1004,7 +1011,10 @@ class TFPolicy(Policy):
"Must pass in RNN state batches for placeholders {}, "
"got {}".format(self._state_inputs, state_batches))
builder.add_feed_dict({self._obs_input: obs_batch})
tree.map_structure(
lambda k, v: builder.add_feed_dict({k: v}),
self._obs_input, obs_batch,
)
if state_batches:
builder.add_feed_dict({
self._seq_lens: np.ones(len(obs_batch))
@ -1106,8 +1116,12 @@ class TFPolicy(Policy):
# Build the feed dict from the batch.
feed_dict = {}
for key, placeholder in self._loss_input_dict.items():
feed_dict[placeholder] = train_batch[key]
for key, placeholders in self._loss_input_dict.items():
tree.map_structure(
lambda ph, v: feed_dict.__setitem__(ph, v),
placeholders,
train_batch[key],
)
state_keys = [
"state_in_{}".format(i) for i in range(len(self._state_inputs))

View file

@ -26,7 +26,7 @@ from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TrainerConfigDict
TensorStructType, TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode # noqa
@ -246,23 +246,24 @@ class TorchPolicy(Policy):
@DeveloperAPI
def compute_actions(
self,
obs_batch: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorStructType], TensorStructType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None,
prev_reward_batch: Union[List[TensorType], TensorType] = None,
prev_action_batch: Union[List[TensorStructType],
TensorStructType] = None,
prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
with torch.no_grad():
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
input_dict = self._lazy_tensor_dict(
SampleBatch({
SampleBatch.CUR_OBS: np.asarray(obs_batch),
}))
input_dict = self._lazy_tensor_dict({
SampleBatch.CUR_OBS: obs_batch
})
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \
np.asarray(prev_action_batch)
@ -405,13 +406,13 @@ class TorchPolicy(Policy):
@DeveloperAPI
def compute_log_likelihoods(
self,
actions: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorType], TensorType],
actions: Union[List[TensorStructType], TensorStructType],
obs_batch: Union[List[TensorStructType], TensorStructType],
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorStructType],
TensorStructType]] = None,
prev_reward_batch: Optional[Union[List[TensorStructType],
TensorStructType]] = None,
actions_normalized: bool = True,
) -> TensorType:

View file

@ -35,10 +35,11 @@ class Filter:
class NoFilter(Filter):
is_concurrent = True
def __init__(self, *args):
pass
def __call__(self, x, update=True):
# Process no further if already np.ndarray, dict, or tuple.
if isinstance(x, (np.ndarray, dict, tuple)):
return x
try:
return np.asarray(x)
except Exception:

View file

@ -5,16 +5,10 @@ import sys
from typing import Any, Optional
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType
from ray.rllib.utils.typing import TensorShape, TensorType
logger = logging.getLogger(__name__)
# Represents a generic tensor type.
TensorType = TensorType
# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = TensorStructType
def try_import_jax(error=False):
"""Tries importing JAX and FLAX and returns both modules (or Nones).
@ -199,36 +193,36 @@ def _torch_stubs():
return None, nn
def get_variable(value,
def get_variable(value: Any,
framework: str = "tf",
trainable: bool = False,
tf_name: str = "unnamed-variable",
torch_tensor: bool = False,
device: Optional[str] = None,
shape: Optional[TensorShape] = None,
dtype: Optional[Any] = None):
dtype: Optional[TensorType] = None) -> Any:
"""
Args:
value (any): The initial value to use. In the non-tf case, this will
value: The initial value to use. In the non-tf case, this will
be returned as is. In the tf case, this could be a tf-Initializer
object.
framework (str): One of "tf", "torch", or None.
trainable (bool): Whether the generated variable should be
framework: One of "tf", "torch", or None.
trainable: Whether the generated variable should be
trainable (tf)/require_grad (torch) or not (default: False).
tf_name (str): For framework="tf": An optional name for the
tf_name: For framework="tf": An optional name for the
tf.Variable.
torch_tensor (bool): For framework="torch": Whether to actually create
torch_tensor: For framework="torch": Whether to actually create
a torch.tensor, or just a python value (default).
device (Optional[torch.Device]): An optional torch device to use for
device: An optional torch device to use for
the created torch tensor.
shape (Optional[TensorShape]): An optional shape to use iff `value`
shape: An optional shape to use iff `value`
does not have any (e.g. if it's an initializer w/o explicit value).
dtype (Optional[TensorType]): An optional dtype to use iff `value` does
dtype: An optional dtype to use iff `value` does
not have any (e.g. if it's an initializer w/o explicit value).
This should always be a numpy dtype (e.g. np.float32, np.int64).
Returns:
any: A framework-specific variable (tf.Variable, torch.tensor, or
A framework-specific variable (tf.Variable, torch.tensor, or
python primitive).
"""
if framework in ["tf2", "tf", "tfe"]:

View file

@ -4,6 +4,7 @@ import numpy as np
import tree # pip install dm_tree
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorStructType, TensorType
tf1, tf, tfv = try_import_tf()
@ -56,12 +57,26 @@ def get_gpu_devices():
return [d.name for d in devices if "GPU" in d.device_type]
def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
def get_placeholder(*,
space=None,
value=None,
name=None,
time_axis=False,
flatten=True):
from ray.rllib.models.catalog import ModelCatalog
if space is not None:
if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
return ModelCatalog.get_action_placeholder(space, None)
if flatten:
return ModelCatalog.get_action_placeholder(space, None)
else:
return tree.map_structure_with_path(
lambda path, component: get_placeholder(
space=component,
name=name + "." + ".".join([str(p) for p in path]),
),
get_base_struct_from_space(space),
)
return tf1.placeholder(
shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
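Aside on the flatten=False branch above: `tree.map_structure_with_path` hands each leaf its key path, which is what produces per-leaf placeholder names such as `name + ".a"`. A tiny standalone sketch of that naming logic (stand-in leaf values, no TF required):

import tree  # pip install dm_tree

space_struct = {"a": "Discrete(2)", "b": {"bb": "Box((2, 3))"}}  # stand-ins
names = tree.map_structure_with_path(
    lambda path, _: "obs." + ".".join(str(p) for p in path), space_struct)
assert names == {"a": "obs.a", "b": {"bb": "obs.b.bb"}}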
@ -138,10 +153,13 @@ def zero_logps_from_actions(actions: TensorStructType) -> TensorType:
def one_hot(x, space):
if isinstance(space, Discrete):
return tf.one_hot(x, space.n)
return tf.one_hot(x, space.n, dtype=tf.float32)
elif isinstance(space, MultiDiscrete):
return tf.concat(
[tf.one_hot(x[:, i], n) for i, n in enumerate(space.nvec)],
[
tf.one_hot(x[:, i], n, dtype=tf.float32)
for i, n in enumerate(space.nvec)
],
axis=-1)
else:
raise ValueError("Unsupported space for `one_hot`: {}".format(space))
@ -200,11 +218,12 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
def make_wrapper(fn):
# Static-graph mode: Create placeholders and make a session call each
# time the wrapped function is called. Return this session call's
# outputs.
# time the wrapped function is called. Returns the output of this
# session call.
if session_or_none is not None:
args_placeholders = []
kwargs_placeholders = {}
symbolic_out = [None]
def call(*args, **kwargs):
@ -215,40 +234,42 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
else:
args_flat.append(a)
args = args_flat
# We have not built any placeholders yet: Do this once here,
# then reuse the same placeholders each time we call this
# function again.
if symbolic_out[0] is None:
with session_or_none.graph.as_default():
for i, v in enumerate(args):
def _create_placeholders(path, value):
if dynamic_shape:
if len(v.shape) > 0:
shape = (None, ) + v.shape[1:]
if len(value.shape) > 0:
shape = (None, ) + value.shape[1:]
else:
shape = ()
else:
shape = v.shape
args_placeholders.append(
tf1.placeholder(
dtype=v.dtype,
shape=shape,
name="arg_{}".format(i)))
for k, v in kwargs.items():
if dynamic_shape:
if len(v.shape) > 0:
shape = (None, ) + v.shape[1:]
else:
shape = ()
else:
shape = v.shape
kwargs_placeholders[k] = \
tf1.placeholder(
dtype=v.dtype,
shape=shape,
name="kwarg_{}".format(k))
shape = value.shape
return tf1.placeholder(
dtype=value.dtype,
shape=shape,
name=".".join([str(p) for p in path]),
)
placeholders = tree.map_structure_with_path(
_create_placeholders, args)
for ph in tree.flatten(placeholders):
args_placeholders.append(ph)
placeholders = tree.map_structure_with_path(
_create_placeholders, kwargs)
for k, ph in placeholders.items():
kwargs_placeholders[k] = ph
symbolic_out[0] = fn(*args_placeholders,
**kwargs_placeholders)
feed_dict = dict(zip(args_placeholders, args))
feed_dict.update(
{kwargs_placeholders[k]: kwargs[k]
for k in kwargs.keys()})
feed_dict = dict(zip(args_placeholders, tree.flatten(args)))
tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v),
kwargs_placeholders, kwargs)
ret = session_or_none.run(symbolic_out[0], feed_dict)
return ret
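One more illustrative aside on the pattern used above (and in the TFPolicy feed-dict changes): because placeholders and incoming values share the same nesting, `tree.map_structure` can pair them leaf by leaf to build the feed dict. The objects below are stand-ins; no TF session is involved:

import numpy as np
import tree  # pip install dm_tree

kwargs_placeholders = {"obs": ("ph_obs_cam", "ph_obs_state")}  # stand-in keys
kwargs = {"obs": (np.zeros((2, 3)), np.ones((2,)))}

feed_dict = {}
tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v),
                   kwargs_placeholders, kwargs)
assert set(feed_dict) == {"ph_obs_cam", "ph_obs_state"}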