[RLlib] No Preprocessors (part 2). (#18468)

Sven Mika 2021-09-23 12:56:45 +02:00 committed by GitHub
parent a2a077b874
commit 61a1274619
25 changed files with 657 additions and 308 deletions

@@ -1460,7 +1460,7 @@ py_test(
 py_test(
     name = "test_preprocessors",
     tags = ["team:ml", "models"],
-    size = "small",
+    size = "medium",
     srcs = ["models/tests/test_preprocessors.py"]
 )

@@ -2659,6 +2659,24 @@ py_test(
     srcs = ["examples/pettingzoo_env.py"],
 )

+py_test(
+    name = "examples/preprocessing_disabled_tf",
+    main = "examples/preprocessing_disabled.py",
+    tags = ["team:ml", "examples", "examples_P"],
+    size = "medium",
+    srcs = ["examples/preprocessing_disabled.py"],
+    args = ["--stop-iters=2"]
+)
+
+py_test(
+    name = "examples/preprocessing_disabled_torch",
+    main = "examples/preprocessing_disabled.py",
+    tags = ["team:ml", "examples", "examples_P"],
+    size = "medium",
+    srcs = ["examples/preprocessing_disabled.py"],
+    args = ["--framework=torch", "--stop-iters=2"]
+)
+
 py_test(
     name = "examples/remote_envs_with_inference_done_on_main_node_tf",
     main = "examples/remote_envs_with_inference_done_on_main_node.py",

@@ -30,12 +30,13 @@ from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \
 from ray.rllib.utils.debug import update_global_seed_if_necessary
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
 from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
-from ray.rllib.utils.framework import try_import_tf, TensorStructType
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.multi_agent import check_multi_agent
 from ray.rllib.utils.spaces import space_utils
 from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
-    PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
+    PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \
+    TrainerConfigDict
 from ray.tune.logger import Logger, UnifiedLogger
 from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
 from ray.tune.resources import Resources
@@ -113,11 +114,6 @@ COMMON_CONFIG: TrainerConfigDict = {
     "model": MODEL_DEFAULTS,
     # Arguments to pass to the policy optimizer. These vary by optimizer.
     "optimizer": {},
-    # Experimental flag, indicating that TFPolicy will handle more than one
-    # loss/optimizer. Set this to True, if you would like to return more than
-    # one loss term from your `loss_fn` and an equal number of optimizers
-    # from your `optimizer_fn`.
-    "_tf_policy_handles_more_than_one_loss": False,

     # === Environment Settings ===
     # Number of steps after which the episode is forced to terminate. Defaults

@@ -483,6 +479,20 @@ COMMON_CONFIG: TrainerConfigDict = {
     # Default value None allows overwriting with nested dicts
     "logger_config": None,

+    # === API deprecations/simplifications/changes ===
+    # Experimental flag.
+    # If True, TFPolicy will handle more than one loss/optimizer.
+    # Set this to True, if you would like to return more than
+    # one loss term from your `loss_fn` and an equal number of optimizers
+    # from your `optimizer_fn`.
+    # In the future, the default for this will be True.
+    "_tf_policy_handles_more_than_one_loss": False,
+    # Experimental flag.
+    # If True, no (observation) preprocessor will be created and
+    # observations will arrive in model as they are returned by the env.
+    # In the future, the default for this will be True.
+    "_disable_preprocessor_api": False,
+
     # === Deprecated keys ===
     # Uses the sync samples optimizer instead of the multi-gpu one. This is
     # usually slower, but you might want to try it if you run into issues with
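
A minimal usage sketch (not part of this diff) for the new experimental flag added above; it assumes a standard Ray 1.x install where `ray.rllib.agents.ppo.PPOTrainer` is available:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={
    "env": "CartPole-v0",
    # Hand observations to the model exactly as the env returns them.
    "_disable_preprocessor_api": True,
    "framework": "tf",
})
print(trainer.train()["episode_reward_mean"])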
@@ -1128,8 +1138,8 @@ class Trainer(Trainable):
                 tuple: The full output of policy.compute_actions() if
                     full_fetch=True or we have an RNN-based Policy.
         """
-        # Preprocess obs and states
-        stateDefined = state is not None
+        # Preprocess obs and states.
+        state_defined = state is not None
         policy = self.get_policy(policy_id)
         filtered_obs, filtered_state = [], []
         for agent_id, ob in observations.items():

@@ -1174,7 +1184,7 @@
             unbatched_states[agent_id] = [s[idx] for s in states]

         # Return only actions or full tuple
-        if stateDefined or full_fetch:
+        if state_defined or full_fetch:
             return actions, unbatched_states, infos
         else:
             return actions

@@ -1529,8 +1539,8 @@
         # Check model config.
         # If no preprocessing, propagate into model's config as well
         # (so model will know, whether inputs are preprocessed or not).
-        if config["preprocessor_pref"] is None:
-            model_config["_no_preprocessor"] = True
+        if config["_disable_preprocessor_api"] is True:
+            model_config["_disable_preprocessor_api"] = True

         # Prev_a/r settings.
         prev_a_r = model_config.get("lstm_use_prev_action_reward",

@@ -3,6 +3,7 @@ from gym.spaces import Space
 import logging
 import math
 import numpy as np
+import tree  # pip install dm_tree
 from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID

@@ -14,6 +15,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
 from ray.rllib.utils.typing import AgentID, EpisodeID, EnvID, PolicyID, \
     TensorType, ViewRequirementsDict
 from ray.util.debug import log_once

@@ -47,7 +49,8 @@ class _AgentCollector:

     _next_unroll_id = 0  # disambiguates unrolls within a single episode

-    def __init__(self, view_reqs):
+    def __init__(self, view_reqs, policy):
+        self.policy = policy
         # Determine the size of the buffer we need for data before the actual
         # episode starts. This is used for 0-buffering of e.g. prev-actions,
         # or internal state inputs.
@ -57,10 +60,28 @@ class _AgentCollector:
(1 (1
if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0) if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0)
for k, vr in view_reqs.items()) for k, vr in view_reqs.items())
# The actual data buffers (lists holding each timestep's data).
self.buffers: Dict[str, List] = {} # The actual data buffers. Keys are column names, values are lists
# that contain the sub-components (e.g. for complex obs spaces) with
# each sub-component holding a list of per-timestep tensors.
# E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,)))
# buffers["obs"] = [
# [0, 1], # <- 1st sub-component of observation
# [np.array([.2, .3]), np.array([.0, -.2])] # <- 2nd sub-component
# ]
# NOTE: infos and state_out_... are not flattened due to them often
# using custom dict values whose structure may vary from timestep to
# timestep.
self.buffers: Dict[str, List[List[TensorType]]] = {}
# Maps column names to an example data item, which may be deeply
# nested. These are used such that we'll know how to unflatten
# the flattened data inside self.buffers when building the
# SampleBatch.
self.buffer_structs: Dict[str, Any] = {}
# The episode ID for the agent for which we collect data. # The episode ID for the agent for which we collect data.
self.episode_id = None self.episode_id = None
# The unroll ID, unique across all rollouts (within a RolloutWorker).
self.unroll_id = None
# The simple timestep count for this agent. Gets increased by one # The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added. # each time a (non-initial!) observation is added.
self.agent_steps = 0 self.agent_steps = 0
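
A small standalone sketch (not from this commit) of the dm_tree calls the collector relies on: `tree.flatten()` turns a complex observation into a flat list of leaves (dicts in sorted-key order), and `tree.unflatten_as()` restores the original structure later when the SampleBatch is built:

import numpy as np
import tree  # pip install dm_tree

obs = {"a": 1, "b": np.array([0.2, 0.3])}   # Dict(a=Discrete(2), b=Box((2,)))
flat = tree.flatten(obs)                    # -> [1, array([0.2, 0.3])]
restored = tree.unflatten_as(obs, flat)     # -> {"a": 1, "b": array([0.2, 0.3])}
assert restored["a"] == 1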
@ -80,6 +101,13 @@ class _AgentCollector:
init_obs (TensorType): The initial observation tensor (after init_obs (TensorType): The initial observation tensor (after
`env.reset()`). `env.reset()`).
""" """
# Store episode ID + unroll ID, which will be constant throughout this
# AgentCollector's lifecycle.
self.episode_id = episode_id
if self.unroll_id is None:
self.unroll_id = _AgentCollector._next_unroll_id
_AgentCollector._next_unroll_id += 1
if SampleBatch.OBS not in self.buffers: if SampleBatch.OBS not in self.buffers:
self._build_buffers( self._build_buffers(
single_row={ single_row={
@ -87,12 +115,19 @@ class _AgentCollector:
SampleBatch.AGENT_INDEX: agent_index, SampleBatch.AGENT_INDEX: agent_index,
SampleBatch.ENV_ID: env_id, SampleBatch.ENV_ID: env_id,
SampleBatch.T: t, SampleBatch.T: t,
SampleBatch.EPS_ID: self.episode_id,
SampleBatch.UNROLL_ID: self.unroll_id,
}) })
self.buffers[SampleBatch.OBS].append(init_obs)
self.episode_id = episode_id # Append data to existing buffers.
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index) flattened = tree.flatten(init_obs)
self.buffers[SampleBatch.ENV_ID].append(env_id) for i, sub_obs in enumerate(flattened):
self.buffers[SampleBatch.T].append(t) self.buffers[SampleBatch.OBS][i].append(sub_obs)
self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
self.buffers[SampleBatch.ENV_ID][0].append(env_id)
self.buffers[SampleBatch.T][0].append(t)
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
None: None:
@ -103,20 +138,40 @@ class _AgentCollector:
row) to be added to buffer. Must contain keys: row) to be added to buffer. Must contain keys:
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS. SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
""" """
if self.unroll_id is None:
self.unroll_id = _AgentCollector._next_unroll_id
_AgentCollector._next_unroll_id += 1
# Next obs -> obs.
assert SampleBatch.OBS not in values assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS] values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS] del values[SampleBatch.NEXT_OBS]
# Make sure EPS_ID stays the same for this agent. Usually, it should
# not be part of `values` anyways. # Make sure EPS_ID/UNROLL_ID stay the same for this agent.
if SampleBatch.EPS_ID in values: if SampleBatch.EPS_ID in values:
assert values[SampleBatch.EPS_ID] == self.episode_id assert values[SampleBatch.EPS_ID] == self.episode_id
del values[SampleBatch.EPS_ID] del values[SampleBatch.EPS_ID]
self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
if SampleBatch.UNROLL_ID in values:
assert values[SampleBatch.UNROLL_ID] == self.unroll_id
del values[SampleBatch.UNROLL_ID]
self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
for k, v in values.items(): for k, v in values.items():
if k not in self.buffers: if k not in self.buffers:
self._build_buffers(single_row=values) self._build_buffers(single_row=values)
self.buffers[k].append(v) # Do not flatten infos, state_out_ and actions.
# Infos/state-outs may be structs that change from timestep to
# timestep. Actions - on the other hand - are already flattened
# in the sampler.
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS
] or k.startswith("state_out_"):
self.buffers[k][0].append(v)
# Flatten all other columns.
else:
flattened = tree.flatten(v)
for i, sub_list in enumerate(self.buffers[k]):
sub_list.append(flattened[i])
self.agent_steps += 1 self.agent_steps += 1
def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch: def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
@ -157,7 +212,9 @@ class _AgentCollector:
# Keep an np-array cache so we don't have to regenerate the # Keep an np-array cache so we don't have to regenerate the
# np-array for different view_cols using to the same data_col. # np-array for different view_cols using to the same data_col.
if data_col not in np_data: if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col]) np_data[data_col] = [
to_float_np_array(d) for d in self.buffers[data_col]
]
# Range of indices on time-axis, e.g. "-50:-1". Together with # Range of indices on time-axis, e.g. "-50:-1". Together with
# the `batch_repeat_value`, this determines the data produced. # the `batch_repeat_value`, this determines the data produced.
@ -171,42 +228,50 @@ class _AgentCollector:
# every n timesteps. # every n timesteps.
if view_req.batch_repeat_value > 1: if view_req.batch_repeat_value > 1:
count = int( count = int(
math.ceil((len(np_data[data_col]) - self.shift_before) math.ceil(
/ view_req.batch_repeat_value)) (len(np_data[data_col][0]) - self.shift_before) /
data = np.asarray([ view_req.batch_repeat_value))
np_data[data_col][self.shift_before + data = [
(i * view_req.batch_repeat_value) + np.asarray([
view_req.shift_from + d[self.shift_before +
obs_shift:self.shift_before + (i * view_req.batch_repeat_value) +
(i * view_req.batch_repeat_value) + view_req.shift_from +
view_req.shift_to + 1 + obs_shift] obs_shift:self.shift_before +
for i in range(count) (i * view_req.batch_repeat_value) +
]) view_req.shift_to + 1 + obs_shift]
for i in range(count)
]) for d in np_data[data_col]
]
# Batch repeat value = 1: Repeat the shift_from/to range at # Batch repeat value = 1: Repeat the shift_from/to range at
# each timestep. # each timestep.
else: else:
d = np_data[data_col] d0 = np_data[data_col][0]
shift_win = view_req.shift_to - view_req.shift_from + 1 shift_win = view_req.shift_to - view_req.shift_from + 1
data_size = d.itemsize * int(np.product(d.shape[1:])) data_size = d0.itemsize * int(np.product(d0.shape[1:]))
strides = [ strides = [
d.itemsize * int(np.product(d.shape[i + 1:])) d0.itemsize * int(np.product(d0.shape[i + 1:]))
for i in range(1, len(d.shape)) for i in range(1, len(d0.shape))
] ]
start = self.shift_before - shift_win + 1 + obs_shift + \ start = self.shift_before - shift_win + 1 + obs_shift + \
view_req.shift_to view_req.shift_to
data = np.lib.stride_tricks.as_strided( data = [
d[start:start + self.agent_steps], np.lib.stride_tricks.as_strided(
[self.agent_steps, shift_win d[start:start + self.agent_steps],
] + [d.shape[i] for i in range(1, len(d.shape))], [self.agent_steps, shift_win
[data_size, data_size] + strides) ] + [d.shape[i] for i in range(1, len(d.shape))],
[data_size, data_size] + strides)
for d in np_data[data_col]
]
# Set of (probably non-consecutive) indices. # Set of (probably non-consecutive) indices.
# Example: # Example:
# shift=[-3, 0] # shift=[-3, 0]
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...] # resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
elif isinstance(view_req.shift, np.ndarray): elif isinstance(view_req.shift, np.ndarray):
data = np_data[data_col][self.shift_before + obs_shift + data = [
view_req.shift] d[self.shift_before + obs_shift + view_req.shift]
for d in np_data[data_col]
]
# Single shift int value. Use the trajectory as-is, and if # Single shift int value. Use the trajectory as-is, and if
# `shift` != 0: shifted by that value. # `shift` != 0: shifted by that value.
else: else:
@ -215,58 +280,77 @@ class _AgentCollector:
# Batch repeat (only provide a value every n timesteps). # Batch repeat (only provide a value every n timesteps).
if view_req.batch_repeat_value > 1: if view_req.batch_repeat_value > 1:
count = int( count = int(
math.ceil((len(np_data[data_col]) - self.shift_before) math.ceil(
/ view_req.batch_repeat_value)) (len(np_data[data_col][0]) - self.shift_before) /
data = np.asarray([ view_req.batch_repeat_value))
np_data[data_col][self.shift_before + ( data = [
i * view_req.batch_repeat_value) + shift] np.asarray([
for i in range(count) d[self.shift_before +
]) (i * view_req.batch_repeat_value) + shift]
for i in range(count)
]) for d in np_data[data_col]
]
# Shift is exactly 0: Use trajectory as is. # Shift is exactly 0: Use trajectory as is.
elif shift == 0: elif shift == 0:
data = np_data[data_col][self.shift_before:] data = [d[self.shift_before:] for d in np_data[data_col]]
# Shift is positive: We still need to 0-pad at the end. # Shift is positive: We still need to 0-pad at the end.
elif shift > 0: elif shift > 0:
data = to_float_np_array( data = [
self.buffers[data_col][self.shift_before + shift:] + [ to_float_np_array(
np.zeros( np.concatenate([
shape=view_req.space.shape, d[self.shift_before + shift:], [
dtype=view_req.space.dtype) np.zeros(
for _ in range(shift) shape=view_req.space.shape,
]) dtype=view_req.space.dtype)
for _ in range(shift)
]
])) for d in np_data[data_col]
]
# Shift is negative: Shift into the already existing and # Shift is negative: Shift into the already existing and
# 0-padded "before" area of our buffers. # 0-padded "before" area of our buffers.
else: else:
data = np_data[data_col][self.shift_before + shift:shift] data = [
d[self.shift_before + shift:shift]
for d in np_data[data_col]
]
if len(data) > 0: if len(data) > 0:
batch_data[view_col] = data if data_col not in self.buffer_structs:
batch_data[view_col] = data[0]
else:
batch_data[view_col] = tree.unflatten_as(
self.buffer_structs[data_col], data)
# Due to possible batch-repeats > 1, columns in the resulting batch # Due to possible batch-repeats > 1, columns in the resulting batch
# may not all have the same batch size. # may not all have the same batch size.
batch = SampleBatch(batch_data) batch = SampleBatch(batch_data)
# Add EPS_ID and UNROLL_ID to batch. # Adjust the seq-lens array depending on the incoming agent sequences.
batch[SampleBatch.EPS_ID] = np.repeat(self.episode_id, batch.count) if self.policy.is_recurrent():
if SampleBatch.UNROLL_ID not in batch: seq_lens = []
# TODO: (sven) Once we have the additional max_seq_len = self.policy.config["model"]["max_seq_len"]
# model.preprocess_train_batch in place (attention net PR), we count = batch.count
# should not even need UNROLL_ID anymore: while count > 0:
# Add "if SampleBatch.UNROLL_ID in view_requirements:" here. seq_lens.append(min(count, max_seq_len))
batch[SampleBatch.UNROLL_ID] = np.repeat( count -= max_seq_len
_AgentCollector._next_unroll_id, batch.count) batch["seq_lens"] = np.array(seq_lens)
_AgentCollector._next_unroll_id += 1 batch.max_seq_len = max_seq_len
# This trajectory is continuing -> Copy data at the end (in the size of # This trajectory is continuing -> Copy data at the end (in the size of
# self.shift_before) to the beginning of buffers and erase everything # self.shift_before) to the beginning of buffers and erase everything
# else. # else.
if not self.buffers[SampleBatch.DONES][-1]: if not self.buffers[SampleBatch.DONES][0][-1]:
# Copy data to beginning of buffer and cut lists. # Copy data to beginning of buffer and cut lists.
if self.shift_before > 0: if self.shift_before > 0:
for k, data in self.buffers.items(): for k, data in self.buffers.items():
-                    self.buffers[k] = data[-self.shift_before:]
+                    # Loop through all sub-components of this column.
+                    for i in range(len(data)):
+                        self.buffers[k][i] = data[i][-self.shift_before:]
self.agent_steps = 0 self.agent_steps = 0
# Reset our unroll_id.
self.unroll_id = None
return batch return batch
def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
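
For reference, the seq-lens bookkeeping added above just chops an agent's trajectory into chunks of at most `max_seq_len`; here is a standalone toy version of the same arithmetic:

def chunk_seq_lens(count, max_seq_len):
    """Split `count` timesteps into sequences of at most `max_seq_len`."""
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return seq_lens

# E.g. a 7-step agent trajectory with max_seq_len=3 -> three sequences.
assert chunk_seq_lens(7, 3) == [3, 3, 1]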
@ -279,12 +363,25 @@ class _AgentCollector:
for col, data in single_row.items(): for col, data in single_row.items():
if col in self.buffers: if col in self.buffers:
continue continue
shift = self.shift_before - (1 if col in [ shift = self.shift_before - (1 if col in [
SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.ENV_ID, SampleBatch.T SampleBatch.ENV_ID, SampleBatch.T, SampleBatch.UNROLL_ID
] else 0) ] else 0)
# Python primitive, tensor, or dict (e.g. INFOs).
self.buffers[col] = [data for _ in range(shift)] # Store all data as flattened lists, except INFOS and state-out
# lists. These are monolithic items (infos is a dict that
# should not be further split, same for state-out items, which
# could be custom dicts as well).
if col in [SampleBatch.INFOS, SampleBatch.ACTIONS
] or col.startswith("state_out_"):
self.buffers[col] = [[data for _ in range(shift)]]
else:
self.buffers[col] = [[v for _ in range(shift)]
for v in tree.flatten(data)]
# Store an example data struct so we know, how to unflatten
# each data col.
self.buffer_structs[col] = data
class _PolicyCollector: class _PolicyCollector:
@ -302,15 +399,13 @@ class _PolicyCollector:
policy (Policy): The policy object. policy (Policy): The policy object.
""" """
self.buffers: Dict[str, List] = collections.defaultdict(list) self.batches = []
self.policy = policy self.policy = policy
# The total timestep count for all agents that use this policy. # The total timestep count for all agents that use this policy.
# NOTE: This is not an env-step count (across n agents). AgentA and # NOTE: This is not an env-step count (across n agents). AgentA and
# agentB, both using this policy, acting in the same episode and both # agentB, both using this policy, acting in the same episode and both
# doing n steps would increase the count by 2*n. # doing n steps would increase the count by 2*n.
self.agent_steps = 0 self.agent_steps = 0
# Seq-lens list of already added agent batches.
self.seq_lens = [] if policy.is_recurrent() else None
def add_postprocessed_batch_for_training( def add_postprocessed_batch_for_training(
self, batch: SampleBatch, self, batch: SampleBatch,
@ -325,22 +420,13 @@ class _PolicyCollector:
view-column needs to be copied at all (not needed for view-column needs to be copied at all (not needed for
training). training).
""" """
for view_col, data in batch.items():
# 1) If col is not in view_requirements, we must have a direct
# child of the base Policy that doesn't do auto-view req creation.
# 2) Col is in view-reqs and needed for training.
view_req = view_requirements.get(view_col)
if view_req is None or view_req.used_for_training:
self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count. # Add the agent's trajectory length to our count.
self.agent_steps += batch.count self.agent_steps += batch.count
# Adjust the seq-lens array depending on the incoming agent sequences. # And remove columns not needed for training.
if self.seq_lens is not None: for view_col, view_req in view_requirements.items():
max_seq_len = self.policy.config["model"]["max_seq_len"] if view_col in batch and not view_req.used_for_training:
count = batch.count del batch[view_col]
while count > 0: self.batches.append(batch)
self.seq_lens.append(min(count, max_seq_len))
count -= max_seq_len
def build(self): def build(self):
"""Builds a SampleBatch for this policy from the collected data. """Builds a SampleBatch for this policy from the collected data.
@ -352,13 +438,11 @@ class _PolicyCollector:
this policy. this policy.
""" """
# Create batch from our buffers. # Create batch from our buffers.
batch = SampleBatch(self.buffers, seq_lens=self.seq_lens) batch = SampleBatch.concat_samples(self.batches)
# Clear buffers for future samples. # Clear batches for future samples.
self.buffers.clear() self.batches = []
# Reset agent steps to 0 and seq-lens to empty list. # Reset agent steps to 0 and seq-lens to empty list.
self.agent_steps = 0 self.agent_steps = 0
if self.seq_lens is not None:
self.seq_lens = []
return batch return batch
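
For context, a short sketch (assuming RLlib is installed; not part of this diff) of `SampleBatch.concat_samples()`, which the policy collector above now uses instead of hand-maintained column buffers:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

b1 = SampleBatch({"obs": np.zeros((3, 4)), "rewards": np.ones(3)})
b2 = SampleBatch({"obs": np.ones((2, 4)), "rewards": np.zeros(2)})
merged = SampleBatch.concat_samples([b1, b2])
assert merged.count == 5  # 3 + 2 timesteps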
@ -479,7 +563,7 @@ class SimpleListCollector(SampleCollector):
# Add initial obs to Trajectory. # Add initial obs to Trajectory.
assert agent_key not in self.agent_collectors assert agent_key not in self.agent_collectors
# TODO: determine exact shift-before based on the view-req shifts. # TODO: determine exact shift-before based on the view-req shifts.
self.agent_collectors[agent_key] = _AgentCollector(view_reqs) self.agent_collectors[agent_key] = _AgentCollector(view_reqs, policy)
self.agent_collectors[agent_key].add_init_obs( self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id, episode_id=episode.episode_id,
agent_index=episode._agent_index(agent_id), agent_index=episode._agent_index(agent_id),
@ -537,7 +621,13 @@ class SimpleListCollector(SampleCollector):
Dict[str, TensorType]: Dict[str, TensorType]:
policy = self.policy_map[policy_id] policy = self.policy_map[policy_id]
keys = self.forward_pass_agent_keys[policy_id] keys = self.forward_pass_agent_keys[policy_id]
buffers = {k: self.agent_collectors[k].buffers for k in keys}
buffers = {}
for k in keys:
collector = self.agent_collectors[k]
buffers[k] = collector.buffers
# Use one agent's buffer_structs (they should all be the same).
buffer_structs = self.agent_collectors[keys[0]].buffer_structs
input_dict = {} input_dict = {}
for view_col, view_req in policy.view_requirements.items(): for view_col, view_req in policy.view_requirements.items():
@ -558,33 +648,57 @@ class SimpleListCollector(SampleCollector):
# Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0]. # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
else: else:
time_indices = view_req.shift + delta time_indices = view_req.shift + delta
data_list = []
# Loop through agents and add-up their data (batch). # Loop through agents and add up their data (batch).
data = None
for k in keys: for k in keys:
if data_col == SampleBatch.EPS_ID: # Buffer for the data does not exist yet: Create dummy
data_list.append(self.agent_collectors[k].episode_id) # (zero) data.
else: if data_col not in buffers[k]:
if data_col not in buffers[k]: if view_req.data_col is not None:
if view_req.data_col is not None: space = policy.view_requirements[
space = policy.view_requirements[ view_req.data_col].space
view_req.data_col].space
else:
space = view_req.space
fill_value = np.zeros_like(space.sample()) \
if isinstance(space, Space) else space
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
if isinstance(time_indices, tuple):
if time_indices[1] == -1:
data_list.append(
buffers[k][data_col][time_indices[0]:])
else:
data_list.append(buffers[k][data_col][time_indices[
0]:time_indices[1] + 1])
else: else:
data_list.append(buffers[k][data_col][time_indices]) space = view_req.space
input_dict[view_col] = np.array(data_list)
if isinstance(space, Space):
fill_value = get_dummy_batch_for_space(
space,
batch_size=0,
)
else:
fill_value = space
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
if data is None:
data = [[] for _ in range(len(buffers[keys[0]][data_col]))]
# `shift_from` and `shift_to` are defined: User wants a
# view with some time-range.
if isinstance(time_indices, tuple):
# `shift_to` == -1: Until the end (including(!) the
# last item).
if time_indices[1] == -1:
for d, b in zip(data, buffers[k][data_col]):
d.append(b[time_indices[0]:])
# `shift_to` != -1: "Normal" range.
else:
for d, b in zip(data, buffers[k][data_col]):
d.append(b[time_indices[0]:time_indices[1] + 1])
# Single index.
else:
for d, b in zip(data, buffers[k][data_col]):
d.append(b[time_indices])
np_data = [np.array(d) for d in data]
if data_col in buffer_structs:
input_dict[view_col] = tree.unflatten_as(
buffer_structs[data_col], np_data)
else:
input_dict[view_col] = np_data[0]
self._reset_inference_calls(policy_id) self._reset_inference_calls(policy_id)

@@ -20,7 +20,7 @@ from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
 from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
 from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
 from ray.rllib.models import ModelCatalog
-from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
+from ray.rllib.models.preprocessors import Preprocessor
 from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
 from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
     OffPolicyEstimate

@@ -44,7 +44,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
 from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
     ModelConfigDict, ModelGradients, ModelWeights, \
     MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
-    SampleBatchType, TrainerConfigDict
+    SampleBatchType
 from ray.util.debug import log_once, disable_log_once_globally, \
     enable_periodic_logging
 from ray.util.iter import ParallelIteratorWorker
@@ -168,7 +168,8 @@ class RolloutWorker(ParallelIteratorWorker):
             env_creator: Callable[[EnvContext], EnvType],
             validate_env: Optional[Callable[[EnvType, EnvContext],
                                             None]] = None,
-            policy_spec: Union[type, Dict[PolicyID, PolicySpec]] = None,
+            policy_spec: Optional[Union[type, Dict[PolicyID,
+                                                   PolicySpec]]] = None,
             policy_mapping_fn: Optional[Callable[
                 [AgentID, "MultiAgentEpisode"], PolicyID]] = None,
             policies_to_train: Optional[List[PolicyID]] = None,

@@ -176,24 +177,24 @@ class RolloutWorker(ParallelIteratorWorker):
             rollout_fragment_length: int = 100,
             count_steps_by: str = "env_steps",
             batch_mode: str = "truncate_episodes",
-            episode_horizon: int = None,
-            preprocessor_pref: Optional[str] = "deepmind",
+            episode_horizon: Optional[int] = None,
+            preprocessor_pref: str = "deepmind",
             sample_async: bool = False,
             compress_observations: bool = False,
             num_envs: int = 1,
-            observation_fn: "ObservationFunction" = None,
+            observation_fn: Optional["ObservationFunction"] = None,
             observation_filter: str = "NoFilter",
             clip_rewards: Optional[Union[bool, float]] = None,
             normalize_actions: bool = True,
             clip_actions: bool = False,
-            env_config: EnvConfigDict = None,
-            model_config: ModelConfigDict = None,
-            policy_config: TrainerConfigDict = None,
+            env_config: Optional[EnvConfigDict] = None,
+            model_config: Optional[ModelConfigDict] = None,
+            policy_config: Optional[PartialTrainerConfigDict] = None,
             worker_index: int = 0,
             num_workers: int = 0,
             record_env: Union[bool, str] = False,
-            log_dir: str = None,
-            log_level: str = None,
+            log_dir: Optional[str] = None,
+            log_level: Optional[str] = None,
             callbacks: Type["DefaultCallbacks"] = None,
             input_creator: Callable[[
                 IOContext

@@ -206,7 +207,7 @@ class RolloutWorker(ParallelIteratorWorker):
             soft_horizon: bool = False,
             no_done_at_end: bool = False,
             seed: int = None,
-            extra_python_environs: dict = None,
+            extra_python_environs: Optional[dict] = None,
             fake_sampler: bool = False,
             spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                   gym.spaces.Space]]] = None,
@@ -258,10 +259,10 @@ class RolloutWorker(ParallelIteratorWorker):
                 that when `num_envs > 1`, episode steps will be buffered
                 until the episode completes, and hence batches may contain
                 significant amounts of off-policy data.
-            episode_horizon (int): Whether to stop episodes at this horizon.
-            preprocessor_pref (Optional[str]): Whether to use no preprocessor
-                (None), RLlib preprocessors ("rllib") or deepmind ("deepmind"),
-                when applicable.
+            episode_horizon: Horizon at which to stop episodes (even if the
+                environment itself has not returned a "done" signal).
+            preprocessor_pref (str): Whether to use RLlib preprocessors
+                ("rllib") or deepmind ("deepmind"), when applicable.
             sample_async (bool): Whether to compute samples asynchronously in
                 the background, which improves throughput but can cause samples
                 to be slightly off-policy.
@ -284,9 +285,9 @@ class RolloutWorker(ParallelIteratorWorker):
env_config (EnvConfigDict): Config to pass to the env creator. env_config (EnvConfigDict): Config to pass to the env creator.
model_config (ModelConfigDict): Config to use when creating the model_config (ModelConfigDict): Config to use when creating the
policy model. policy model.
policy_config (TrainerConfigDict): Config to pass to the policy. policy_config: Config to pass to the
In the multi-agent case, this config will be merged with the policy. In the multi-agent case, this config will be merged
per-policy configs specified by `policy_spec`. with the per-policy configs specified by `policy_spec`.
worker_index (int): For remote workers, this should be set to a worker_index (int): For remote workers, this should be set to a
non-zero and unique value. This index is passed to created envs non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker. through EnvContext so that envs can be configured per worker.
@ -378,7 +379,7 @@ class RolloutWorker(ParallelIteratorWorker):
ParallelIteratorWorker.__init__(self, gen_rollouts, False) ParallelIteratorWorker.__init__(self, gen_rollouts, False)
policy_config: TrainerConfigDict = policy_config or {} policy_config = policy_config or {}
if (tf1 and policy_config.get("framework") in ["tf2", "tfe"] if (tf1 and policy_config.get("framework") in ["tf2", "tfe"]
# This eager check is necessary for certain all-framework tests # This eager check is necessary for certain all-framework tests
# that use tf's eager_mode() context generator. # that use tf's eager_mode() context generator.
@ -400,7 +401,7 @@ class RolloutWorker(ParallelIteratorWorker):
num_workers=num_workers, num_workers=num_workers,
) )
self.env_context = env_context self.env_context = env_context
self.policy_config: TrainerConfigDict = policy_config self.policy_config: PartialTrainerConfigDict = policy_config
if callbacks: if callbacks:
self.callbacks: "DefaultCallbacks" = callbacks() self.callbacks: "DefaultCallbacks" = callbacks()
else: else:
@ -424,10 +425,10 @@ class RolloutWorker(ParallelIteratorWorker):
self.batch_mode: str = batch_mode self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = False \ self.preprocessing_enabled: bool = False \
if preprocessor_pref is None else True if policy_config.get("_disable_preprocessor_api") else True
self.observation_filter = observation_filter self.observation_filter = observation_filter
self.last_batch: SampleBatchType = None self.last_batch: Optional[SampleBatchType] = None
self.global_vars: dict = None self.global_vars: Optional[dict] = None
self.fake_sampler: bool = fake_sampler self.fake_sampler: bool = fake_sampler
# Update the global seed for numpy/random/tf-eager/torch if we are not # Update the global seed for numpy/random/tf-eager/torch if we are not
@ -1076,10 +1077,9 @@ class RolloutWorker(ParallelIteratorWorker):
space of the policy to add. space of the policy to add.
action_space (Optional[gym.spaces.Space]): The action space action_space (Optional[gym.spaces.Space]): The action space
of the policy to add. of the policy to add.
config (Optional[PartialTrainerConfigDict]): The config config: The config overrides for the policy to add.
overrides for the policy to add. policy_config: The base config of the Trainer object owning this
policy_config (Optional[TrainerConfigDict]): The base config of the RolloutWorker.
Trainer object owning this RolloutWorker.
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode], policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
PolicyID]]): An optional (updated) policy mapping function to PolicyID]]): An optional (updated) policy mapping function to
use from here on. Note that already ongoing episodes will not use from here on. Note that already ongoing episodes will not
@ -1340,7 +1340,7 @@ class RolloutWorker(ParallelIteratorWorker):
def _build_policy_map( def _build_policy_map(
self, self,
policy_dict: MultiAgentPolicyConfigDict, policy_dict: MultiAgentPolicyConfigDict,
policy_config: TrainerConfigDict, policy_config: PartialTrainerConfigDict,
session_creator: Optional[Callable[[], "tf1.Session"]] = None, session_creator: Optional[Callable[[], "tf1.Session"]] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]: ) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
@ -1371,13 +1371,7 @@ class RolloutWorker(ParallelIteratorWorker):
if preprocessor is not None: if preprocessor is not None:
obs_space = preprocessor.observation_space obs_space = preprocessor.observation_space
else: else:
self.preprocessors[name] = NoPreprocessor(obs_space) self.preprocessors[name] = None
if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
raise ValueError(
"Found raw Tuple|Dict space as input to policy. "
"Please preprocess these observations with a "
"Tuple|DictFlatteningPreprocessor.")
self.policy_map.create_policy(name, orig_cls, obs_space, act_space, self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
conf, merged_conf) conf, merged_conf)

@ -275,14 +275,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
rollout_fragment_length: int, rollout_fragment_length: int,
count_steps_by: str = "env_steps", count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks", callbacks: "DefaultCallbacks",
horizon: int = None, horizon: Optional[int] = None,
multiple_episodes_in_batch: bool = False, multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True, normalize_actions: bool = True,
clip_actions: bool = False, clip_actions: bool = False,
blackhole_outputs: bool = False, blackhole_outputs: bool = False,
soft_horizon: bool = False, soft_horizon: bool = False,
no_done_at_end: bool = False, no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None, observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None, sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False, render: bool = False,
# Obsolete. # Obsolete.
@ -308,7 +308,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
Refers to the unit of `rollout_fragment_length`. Refers to the unit of `rollout_fragment_length`.
callbacks (Callbacks): The Callbacks object to use when episode callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout. events happen during rollout.
horizon (Optional[int]): Hard-reset the Env horizon: Hard-reset the Env after this many timesteps.
multiple_episodes_in_batch (bool): Whether to pack multiple multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size. exactly `rollout_fragment_length` in size.
@ -452,7 +452,7 @@ def _env_runner(
worker: "RolloutWorker", worker: "RolloutWorker",
base_env: BaseEnv, base_env: BaseEnv,
extra_batch_callback: Callable[[SampleBatchType], None], extra_batch_callback: Callable[[SampleBatchType], None],
horizon: int, horizon: Optional[int],
normalize_actions: bool, normalize_actions: bool,
clip_actions: bool, clip_actions: bool,
multiple_episodes_in_batch: bool, multiple_episodes_in_batch: bool,
@ -470,7 +470,7 @@ def _env_runner(
worker (RolloutWorker): Reference to the current rollout worker. worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv. base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to. extra_batch_callback (fn): function to send extra batch data to.
horizon (int): Horizon of the episode. horizon: Horizon of the episode.
multiple_episodes_in_batch (bool): Whether to pack multiple multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size. `rollout_fragment_length` in size.

@ -117,21 +117,24 @@ class TestRolloutWorker(unittest.TestCase):
ev.stop() ev.stop()
def test_batch_ids(self): def test_batch_ids(self):
fragment_len = 100
ev = RolloutWorker( ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"), env_creator=lambda _: gym.make("CartPole-v0"),
policy_spec=MockPolicy, policy_spec=MockPolicy,
rollout_fragment_length=1) rollout_fragment_length=fragment_len)
batch1 = ev.sample() batch1 = ev.sample()
batch2 = ev.sample() batch2 = ev.sample()
self.assertEqual(len(set(batch1["unroll_id"])), 1) unroll_ids_1 = set(batch1["unroll_id"])
self.assertEqual(len(set(batch2["unroll_id"])), 1) unroll_ids_2 = set(batch2["unroll_id"])
self.assertEqual( # Assert no overlap of unroll IDs between sample() calls.
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2) self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1))
# CartPole episodes should be short initially: Expect more than one
# unroll ID in each batch.
self.assertTrue(len(unroll_ids_1) > 1)
self.assertTrue(len(unroll_ids_2) > 1)
ev.stop() ev.stop()
def test_global_vars_update(self): def test_global_vars_update(self):
# Allow for Unittest run.
ray.init(num_cpus=5, ignore_reinit_error=True)
for fw in framework_iterator(frameworks=("tf2", "tf")): for fw in framework_iterator(frameworks=("tf2", "tf")):
agent = A2CTrainer( agent = A2CTrainer(
env="CartPole-v0", env="CartPole-v0",

@ -185,31 +185,49 @@ class TestTrajectoryViewAPI(unittest.TestCase):
policy_mapping_fn=None, policy_mapping_fn=None,
num_envs=1, num_envs=1,
) )
# Add the next action to the view reqs of the policy. # Add the next action (a') and 2nd next action (a'') to the view
# requirements of the policy.
# This should be visible then in postprocessing and train batches. # This should be visible then in postprocessing and train batches.
# Switch off for action computations (can't be there as we don't know # Switch off for action computations (can't be there as we don't know
# the next action already at action computation time). # the next actions already at action computation time).
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[ rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"next_actions"] = ViewRequirement( "next_actions"] = ViewRequirement(
SampleBatch.ACTIONS, SampleBatch.ACTIONS,
shift=1, shift=1,
space=action_space, space=action_space,
used_for_compute_actions=False) used_for_compute_actions=False)
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"2nd_next_actions"] = ViewRequirement(
SampleBatch.ACTIONS,
shift=2,
space=action_space,
used_for_compute_actions=False)
# Make sure, we have DONEs as well. # Make sure, we have DONEs as well.
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[ rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"dones"] = ViewRequirement() "dones"] = ViewRequirement()
batch = rollout_worker_w_api.sample() batch = rollout_worker_w_api.sample()
self.assertTrue("next_actions" in batch) self.assertTrue("next_actions" in batch)
self.assertTrue("2nd_next_actions" in batch)
expected_a_ = None # expected next action expected_a_ = None # expected next action
expected_a__ = None # expected 2nd next action
for i in range(len(batch["actions"])): for i in range(len(batch["actions"])):
a, d, a_ = batch["actions"][i], batch["dones"][i], \ a, d, a_, a__ = \
batch["next_actions"][i] batch["actions"][i], batch["dones"][i], \
if not d and expected_a_ is not None: batch["next_actions"][i], batch["2nd_next_actions"][i]
check(a, expected_a_) # Episode done: next action and 2nd next action should be 0.
elif d: if d:
check(a_, 0) check(a_, 0)
check(a__, 0)
expected_a_ = None expected_a_ = None
expected_a__ = None
continue continue
# Episode is not done and we have an expected next-a.
if expected_a_ is not None:
check(a, expected_a_)
if expected_a__ is not None:
check(a_, expected_a__)
expected_a__ = a__
expected_a_ = a_ expected_a_ = a_
def test_traj_view_lstm_functionality(self): def test_traj_view_lstm_functionality(self):

@@ -57,7 +57,7 @@ class MyCallbacks(DefaultCallbacks):
                        env_index: int, **kwargs):
         # Make sure this episode is really done.
         assert episode.batch_builder.policy_collectors[
-            "default_policy"].buffers["dones"][-1], \
+            "default_policy"].batches[-1]["dones"][-1], \
             "ERROR: `on_episode_end()` should only be called " \
             "after episode is done!"
         pole_angle = np.mean(episode.user_data["pole_angles"])

@@ -49,6 +49,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
                 SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                 SampleBatch.REWARDS: ViewRequirement(),
                 SampleBatch.DONES: ViewRequirement(),
+                SampleBatch.UNROLL_ID: ViewRequirement(),
             },
             **self.model.view_requirements)

@ -0,0 +1,107 @@
"""
Example for using _disable_preprocessor_api=True to disable all preprocessing.
This example shows:
- How a complex observation space from the env is handled directly by the
model.
- Complex observations are flattened into lists of tensors and as such
stored by the SampleCollectors.
- This has the advantage that preprocessing happens in batched fashion
(in the model).
"""
import argparse
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import os
import ray
from ray import tune
def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()
# general args
parser.add_argument(
"--run", default="PPO", help="The RLlib-registered algorithm to use.")
parser.add_argument("--num-cpus", type=int, default=3)
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument(
"--stop-iters",
type=int,
default=200,
help="Number of iterations to train.")
parser.add_argument(
"--stop-timesteps",
type=int,
default=500000,
help="Number of timesteps to train.")
parser.add_argument(
"--stop-reward",
type=float,
default=80.0,
help="Reward at which we stop training.")
parser.add_argument(
"--as-test",
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
"--no-tune",
action="store_true",
help="Run without Tune using a manual train loop instead. Here,"
"there is no TensorBoard support.")
parser.add_argument(
"--local-mode",
action="store_true",
help="Init Ray in local mode for easier debugging.")
args = parser.parse_args()
print(f"Running with following CLI args: {args}")
return args
if __name__ == "__main__":
args = get_cli_args()
ray.init(local_mode=args.local_mode)
config = {
"env": "ray.rllib.examples.env.random_env.RandomEnv",
"env_config": {
"config": {
"observation_space": Dict({
"a": Discrete(2),
"b": Dict({
"ba": Discrete(3),
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
}),
"c": Tuple((MultiDiscrete([2, 3]), Discrete(2))),
"d": Box(-1.0, 1.0, (2, ), dtype=np.int32),
}),
},
},
# Set this to True to enforce no preprocessors being used.
# Complex observations now arrive directly in the model as
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
"_disable_preprocessor_api": True,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
"framework": args.framework,
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop, verbose=2)
ray.shutdown()
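
To make the structure comment in the config above concrete, here is a hedged sketch (not part of the example file) of how raw samples of this observation space batch up when no preprocessor is used; each leaf is stacked along a new batch axis while the Dict/Tuple structure is preserved:

import numpy as np
import tree  # pip install dm_tree
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple

space = Dict({
    "a": Discrete(2),
    "b": Dict({
        "ba": Discrete(3),
        "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
    }),
    "c": Tuple((MultiDiscrete([2, 3]), Discrete(2))),
    "d": Box(-1.0, 1.0, (2, ), dtype=np.int32),
})
samples = [space.sample() for _ in range(4)]
# Stack each leaf along a new batch axis, keeping the nested structure intact.
batch = tree.map_structure(lambda *leaves: np.stack(leaves), *samples)
print(tree.map_structure(lambda x: x.shape, batch))
# e.g. {"a": (4,), "b": {"ba": (4,), "bb": (4, 2, 3)}, "c": ((4, 2), (4,)), "d": (4, 2)}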

@@ -502,7 +502,8 @@ class LocalReplayBuffer(ParallelIteratorWorker):
                     # If SampleBatch has prio-replay weights, average
                     # over these to use as a weight for the entire
                     # sequence.
-                    if "weights" in time_slice:
+                    if "weights" in time_slice and \
+                            len(time_slice["weights"]):
                         weight = np.mean(time_slice["weights"])
                     else:
                         weight = None

@@ -46,9 +46,9 @@ MODEL_DEFAULTS: ModelConfigDict = {
     "_use_default_native_models": False,
     # Experimental flag.
     # If True, user specified no preprocessor to be created
-    # (via config.preprocessor_pref=None). If True, observations will arrive
-    # in model as they are returned by the env.
-    "_no_preprocessing": False,
+    # (via config._disable_preprocessor_api=True). If True, observations
+    # will arrive in model as they are returned by the env.
+    "_disable_preprocessor_api": False,

     # === Built-in options ===
     # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py

@@ -217,12 +217,12 @@ class ModelV2:
         else:
             restored = input_dict.copy()

-        # No Preprocessor used: `config.preprocessor_pref`=None.
+        # No Preprocessor used: `config._disable_preprocessor_api`=True.
         # TODO: This is unnecessary for when no preprocessor is used.
         #  Obs are not flat then anymore. However, we'll keep this
         #  here for backward-compatibility until Preprocessors have
         #  been fully deprecated.
-        if self.model_config.get("_no_preprocessing"):
+        if self.model_config.get("_disable_preprocessor_api"):
             restored["obs_flat"] = input_dict["obs"]
         # Input to this Model went through a Preprocessor.
         # Generate extra keys: "obs_flat" (vs "obs", which will hold the

@@ -1,7 +1,7 @@
 from typing import List

 from ray.rllib.utils.annotations import PublicAPI
-from ray.rllib.utils.framework import TensorType, TensorStructType
+from ray.rllib.utils.typing import TensorType, TensorStructType


 @PublicAPI

@ -3,14 +3,57 @@ from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np import numpy as np
import unittest import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \ from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \
get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \ get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \
OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor
from ray.rllib.utils.test_utils import check from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator
class TestPreprocessors(unittest.TestCase): class TestPreprocessors(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_preprocessing_disabled(self):
config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "ray.rllib.examples.env.random_env.RandomEnv"
config["env_config"] = {
"config": {
"observation_space": Dict({
"a": Discrete(5),
"b": Dict({
"ba": Discrete(4),
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)
}),
"c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
"d": Box(-1.0, 1.0, (1, ), dtype=np.int32),
}),
},
}
# Set this to True to enforce no preprocessors being used.
# Complex observations now arrive directly in the model as
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
config["_disable_preprocessor_api"] = True
num_iterations = 1
# Only supported for tf so far.
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config=config)
for i in range(num_iterations):
print(trainer.train())
check_compute_single_action(trainer)
trainer.stop()
def test_gym_preprocessors(self): def test_gym_preprocessors(self):
p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0")) p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
self.assertEqual(type(p1), NoPreprocessor) self.assertEqual(type(p1), NoPreprocessor)

View file

@ -133,7 +133,7 @@ class ComplexInputNetwork(TFModelV2):
cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
outs.append(cnn_out) outs.append(cnn_out)
elif i in self.one_hot: elif i in self.one_hot:
if component.dtype in [tf.int32, tf.int64, tf.uint8]: if "int" in component.dtype.name:
outs.append( outs.append(
one_hot(component, self.flattened_input_space[i])) one_hot(component, self.flattened_input_space[i]))
else: else:
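The dtype check above was switched from an explicit tf.int32/int64/uint8 list to a name test. The numpy sketch below (TF dtype names mirror numpy's) shows that "int" in dtype.name covers signed and unsigned integer widths, while float components fall through to the flatten branch.

import numpy as np

for dt in (np.int32, np.int64, np.uint8, np.float32):
    # Only the integer dtypes contain "int" in their name.
    print(np.dtype(dt).name, "int" in np.dtype(dt).name)
# int32 True / int64 True / uint8 True / float32 False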

View file

@ -546,6 +546,7 @@ class DynamicTFPolicy(TFPolicy):
# Skip action dist inputs placeholder (do later). # Skip action dist inputs placeholder (do later).
elif view_col == SampleBatch.ACTION_DIST_INPUTS: elif view_col == SampleBatch.ACTION_DIST_INPUTS:
continue continue
# This is a tower, input placeholders already exist.
elif view_col in existing_inputs: elif view_col in existing_inputs:
input_dict[view_col] = existing_inputs[view_col] input_dict[view_col] = existing_inputs[view_col]
# All others. # All others.
@ -554,10 +555,15 @@ class DynamicTFPolicy(TFPolicy):
if view_req.used_for_training: if view_req.used_for_training:
# Create a +time-axis placeholder if the shift is not an # Create a +time-axis placeholder if the shift is not an
# int (range or list of ints). # int (range or list of ints).
flatten = view_col not in [
SampleBatch.OBS, SampleBatch.NEXT_OBS] or \
not self.config["_disable_preprocessor_api"]
input_dict[view_col] = get_placeholder( input_dict[view_col] = get_placeholder(
space=view_req.space, space=view_req.space,
name=view_col, name=view_col,
time_axis=time_axis) time_axis=time_axis,
flatten=flatten,
)
dummy_batch = self._get_dummy_batch_from_view_requirements( dummy_batch = self._get_dummy_batch_from_view_requirements(
batch_size=32) batch_size=32)
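Restating the new placeholder rule as a tiny standalone sketch: observation columns keep their nested (un-flattened) structure only when the preprocessor API is disabled; every other view column is still flattened. The string constants are assumptions matching SampleBatch.OBS ("obs") and SampleBatch.NEXT_OBS ("new_obs").

# Assumed string values of SampleBatch.OBS / SampleBatch.NEXT_OBS.
OBS_COLS = {"obs", "new_obs"}

def should_flatten(view_col, disable_preprocessor_api):
    # Flatten everything except the observation columns when the
    # preprocessor API is disabled.
    return view_col not in OBS_COLS or not disable_preprocessor_api

assert should_flatten("actions", True) is True
assert should_flatten("obs", False) is True
assert should_flatten("obs", True) is False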

View file

@ -5,6 +5,7 @@ It supports both traced and non-traced eager execution modes."""
import functools import functools
import logging import logging
import threading import threading
import tree # pip install dm_tree
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from ray.util.debug import log_once from ray.util.debug import log_once
@ -425,10 +426,12 @@ def build_eager_tf_policy(
if not tf1.executing_eagerly(): if not tf1.executing_eagerly():
tf1.enable_eager_execution() tf1.enable_eager_execution()
input_dict = { input_dict = SampleBatch(
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), {
"is_training": tf.constant(False), SampleBatch.CUR_OBS: tree.map_structure(
} lambda s: tf.convert_to_tensor(s), obs_batch),
},
_is_training=tf.constant(False))
if prev_action_batch is not None: if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \ input_dict[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch) tf.convert_to_tensor(prev_action_batch)
@ -478,7 +481,7 @@ def build_eager_tf_policy(
self._is_training = False self._is_training = False
self._state_in = state_batches or [] self._state_in = state_batches or []
# Calculate RNN sequence lengths. # Calculate RNN sequence lengths.
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0] batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \ seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
else None else None
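With dict or tuple observations there is no single obs_batch.shape[0] anymore; the batch size is read off the first flattened leaf, as in this small numpy/dm_tree sketch (shapes illustrative).

import numpy as np
import tree  # pip install dm_tree

obs_batch = {
    "cam": np.zeros((16, 84, 84, 3)),
    "state": (np.zeros((16, 4)), np.zeros((16, 2))),
}
# Every leaf shares the same leading (batch) dimension, so the first one suffices.
batch_size = tree.flatten(obs_batch)[0].shape[0]
assert batch_size == 16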

View file

@ -4,6 +4,7 @@ import gym
from gym.spaces import Box from gym.spaces import Box
import logging import logging
import numpy as np import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional, TYPE_CHECKING from typing import Dict, List, Optional, TYPE_CHECKING
from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.catalog import ModelCatalog
@ -17,7 +18,7 @@ from ray.rllib.utils.spaces.space_utils import clip_action, \
get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \ get_base_struct_from_space, get_dummy_batch_for_space, unbatch, \
unsquash_action unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict, Tuple, Union TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch() torch, _ = try_import_torch()
@ -132,10 +133,12 @@ class Policy(metaclass=ABCMeta):
@DeveloperAPI @DeveloperAPI
def compute_actions( def compute_actions(
self, self,
obs_batch: Union[List[TensorType], TensorType], obs_batch: Union[List[TensorStructType], TensorStructType],
state_batches: Optional[List[TensorType]] = None, state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None, prev_action_batch: Union[List[TensorStructType],
prev_reward_batch: Union[List[TensorType], TensorType] = None, TensorStructType] = None,
prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None, info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None, explore: Optional[bool] = None,
@ -145,14 +148,10 @@ class Policy(metaclass=ABCMeta):
"""Computes actions for the current policy. """Computes actions for the current policy.
Args: Args:
obs_batch (Union[List[TensorType], TensorType]): Batch of obs_batch: Batch of observations.
observations. state_batches: List of RNN state input batches, if any.
state_batches (Optional[List[TensorType]]): List of RNN state input prev_action_batch: Batch of previous action values.
batches, if any. prev_reward_batch: Batch of previous rewards.
prev_action_batch (Union[List[TensorType], TensorType]): Batch of
previous action values.
prev_reward_batch (Union[List[TensorType], TensorType]): Batch of
previous rewards.
info_batch (Optional[Dict[str, list]]): Batch of info objects. info_batch (Optional[Dict[str, list]]): Batch of info objects.
episodes (Optional[List[MultiAgentEpisode]] ): List of episodes (Optional[List[MultiAgentEpisode]] ): List of
MultiAgentEpisode, one for each obs in obs_batch. This provides MultiAgentEpisode, one for each obs in obs_batch. This provides
@ -181,10 +180,10 @@ class Policy(metaclass=ABCMeta):
@DeveloperAPI @DeveloperAPI
def compute_single_action( def compute_single_action(
self, self,
obs: TensorType, obs: TensorStructType,
state: Optional[List[TensorType]] = None, state: Optional[List[TensorType]] = None,
prev_action: Optional[TensorType] = None, prev_action: Optional[TensorStructType] = None,
prev_reward: Optional[TensorType] = None, prev_reward: Optional[TensorStructType] = None,
info: dict = None, info: dict = None,
episode: Optional["MultiAgentEpisode"] = None, episode: Optional["MultiAgentEpisode"] = None,
clip_actions: bool = None, clip_actions: bool = None,
@ -192,37 +191,34 @@ class Policy(metaclass=ABCMeta):
timestep: Optional[int] = None, timestep: Optional[int] = None,
unsquash_actions: bool = None, unsquash_actions: bool = None,
**kwargs) -> \ **kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
"""Unbatched version of compute_actions. """Unbatched version of compute_actions.
Args: Args:
obs (TensorType): Single observation. obs: Single observation.
state (Optional[List[TensorType]]): List of RNN state inputs, if state: List of RNN state inputs, if any.
any. prev_action: Previous action value, if any.
prev_action (Optional[TensorType]): Previous action value, if any. prev_reward: Previous reward, if any.
prev_reward (Optional[TensorType]): Previous reward, if any.
info (dict): Info object, if any. info (dict): Info object, if any.
episode (Optional[MultiAgentEpisode]): this provides access to all episode: this provides access to all
of the internal episode state, which may be useful for of the internal episode state, which may be useful for
model-based or multi-agent algorithms. model-based or multi-agent algorithms.
unsquash_actions (bool): Should actions be unsquashed according to unsquash_actions: Should actions be unsquashed according to
the Policy's action space? the Policy's action space?
clip_actions (bool): Should actions be clipped according to the clip_actions: Should actions be clipped according to the
Policy's action space? Policy's action space?
explore (Optional[bool]): Whether to pick an exploitation or explore: Whether to pick an exploitation or
exploration action exploration action
(default: None -> use self.config["explore"]). (default: None -> use self.config["explore"]).
timestep (Optional[int]): The current (sampling) time step. timestep: The current (sampling) time step.
Keyword Args: Keyword Args:
kwargs: Forward compatibility. kwargs: Forward compatibility.
Returns: Returns:
Tuple: - actions: Single action.
- actions (TensorType): Single action. - state_outs: List of RNN state outputs, if any.
- state_outs (List[TensorType]): List of RNN state outputs, - info: Dictionary of extra features, if any.
if any.
- info (dict): Dictionary of extra features, if any.
""" """
# If policy works in normalized space, we should unsquash the action. # If policy works in normalized space, we should unsquash the action.
# Use value of config.normalize_actions, if None. # Use value of config.normalize_actions, if None.
@ -253,7 +249,7 @@ class Policy(metaclass=ABCMeta):
] ]
out = self.compute_actions( out = self.compute_actions(
[obs], tree.map_structure(lambda s: np.array([s]), obs),
state_batch, state_batch,
prev_action_batch=prev_action_batch, prev_action_batch=prev_action_batch,
prev_reward_batch=prev_reward_batch, prev_reward_batch=prev_reward_batch,
@ -805,7 +801,8 @@ class Policy(metaclass=ABCMeta):
for key in all_accessed_keys: for key in all_accessed_keys:
if key not in self.view_requirements and \ if key not in self.view_requirements and \
key != SampleBatch.SEQ_LENS: key != SampleBatch.SEQ_LENS:
self.view_requirements[key] = ViewRequirement() self.view_requirements[key] = ViewRequirement(
used_for_compute_actions=False)
if self._loss: if self._loss:
# Tag those only needed for post-processing (with some # Tag those only needed for post-processing (with some
# exceptions). # exceptions).
@ -852,7 +849,7 @@ class Policy(metaclass=ABCMeta):
""" """
ret = {} ret = {}
for view_col, view_req in self.view_requirements.items(): for view_col, view_req in self.view_requirements.items():
if self.config["preprocessor_pref"] is not None and \ if not self.config["_disable_preprocessor_api"] and \
isinstance(view_req.space, isinstance(view_req.space,
(gym.spaces.Dict, gym.spaces.Tuple)): (gym.spaces.Dict, gym.spaces.Tuple)):
_, shape = ModelCatalog.get_action_shape( _, shape = ModelCatalog.get_action_shape(
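The [obs] wrapping in compute_single_action was replaced by a leaf-wise tree.map_structure(lambda s: np.array([s]), obs), which adds a length-1 batch axis to every component of a nested observation. A small standalone sketch (values illustrative):

import numpy as np
import tree  # pip install dm_tree

single_obs = {"a": 3, "b": (np.array([0.1, -0.2], np.float32), np.array([1, 2]))}
# Turn one nested observation into a batch of size 1, leaf by leaf.
obs_batch = tree.map_structure(lambda s: np.array([s]), single_obs)
print(tree.map_structure(lambda x: x.shape, obs_batch))
# {'a': (1,), 'b': ((1, 2), (1, 2))}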

View file

@ -364,6 +364,9 @@ class SampleBatch(dict):
for i, p in enumerate(path): for i, p in enumerate(path):
if i == len(path) - 1: if i == len(path) - 1:
curr[p] = value[permutation] curr[p] = value[permutation]
# Translate into list (tuples are immutable).
if isinstance(curr[p], tuple):
curr[p] = list(curr[p])
curr = curr[p] curr = curr[p]
tree.map_structure_with_path(_permutate_in_place, self) tree.map_structure_with_path(_permutate_in_place, self)
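The tuple-to-list translation is needed because the shuffle permutes columns in place and tuple elements cannot be reassigned; a minimal sketch:

import numpy as np

permutation = np.array([2, 0, 1])
col = (np.arange(3), np.arange(3) * 10)  # tuple column (immutable)
col = list(col)                          # translate into list first
for i, arr in enumerate(col):
    col[i] = arr[permutation]            # now assignable in place
print(col)  # [array([2, 0, 1]), array([20,  0, 10])]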

View file

@ -4,6 +4,7 @@ import logging
import math import math
import numpy as np import numpy as np
import os import os
import tree # pip install dm_tree
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
import ray import ray
@ -435,15 +436,16 @@ class TFPolicy(Policy):
fetched = builder.get(to_fetch) fetched = builder.get(to_fetch)
# Update our global timestep by the batch size. # Update our global timestep by the batch size.
self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \ self.global_timestep += \
else obs_batch.shape[0] len(obs_batch) if isinstance(obs_batch, list) \
else tree.flatten(obs_batch)[0].shape[0]
return fetched return fetched
@override(Policy) @override(Policy)
def compute_actions_from_input_dict( def compute_actions_from_input_dict(
self, self,
input_dict: Dict[str, TensorType], input_dict: Union[SampleBatch, Dict[str, TensorType]],
explore: bool = None, explore: bool = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["MultiAgentEpisode"]] = None,
@ -464,6 +466,7 @@ class TFPolicy(Policy):
# Update our global timestep by the batch size. # Update our global timestep by the batch size.
self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \ self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
else len(input_dict) if isinstance(input_dict, SampleBatch) \
else obs_batch.shape[0] else obs_batch.shape[0]
return fetched return fetched
@ -965,7 +968,11 @@ class TFPolicy(Policy):
if hasattr(self, "_input_dict"): if hasattr(self, "_input_dict"):
for key, value in input_dict.items(): for key, value in input_dict.items():
if key in self._input_dict: if key in self._input_dict:
builder.add_feed_dict({self._input_dict[key]: value}) # Handle complex/nested spaces as well.
tree.map_structure(
lambda k, v: builder.add_feed_dict({k: v}),
self._input_dict[key], value,
)
# For policies that inherit directly from TFPolicy. # For policies that inherit directly from TFPolicy.
else: else:
builder.add_feed_dict({ builder.add_feed_dict({
@ -1004,7 +1011,10 @@ class TFPolicy(Policy):
"Must pass in RNN state batches for placeholders {}, " "Must pass in RNN state batches for placeholders {}, "
"got {}".format(self._state_inputs, state_batches)) "got {}".format(self._state_inputs, state_batches))
builder.add_feed_dict({self._obs_input: obs_batch}) tree.map_structure(
lambda k, v: builder.add_feed_dict({k: v}),
self._obs_input, obs_batch,
)
if state_batches: if state_batches:
builder.add_feed_dict({ builder.add_feed_dict({
self._seq_lens: np.ones(len(obs_batch)) self._seq_lens: np.ones(len(obs_batch))
@ -1106,8 +1116,12 @@ class TFPolicy(Policy):
# Build the feed dict from the batch. # Build the feed dict from the batch.
feed_dict = {} feed_dict = {}
for key, placeholder in self._loss_input_dict.items(): for key, placeholders in self._loss_input_dict.items():
feed_dict[placeholder] = train_batch[key] tree.map_structure(
lambda ph, v: feed_dict.__setitem__(ph, v),
placeholders,
train_batch[key],
)
state_keys = [ state_keys = [
"state_in_{}".format(i) for i in range(len(self._state_inputs)) "state_in_{}".format(i) for i in range(len(self._state_inputs))

View file

@ -26,7 +26,7 @@ from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \ from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor convert_to_torch_tensor
from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \ from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TrainerConfigDict TensorStructType, TrainerConfigDict
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode # noqa from ray.rllib.evaluation import MultiAgentEpisode # noqa
@ -246,23 +246,24 @@ class TorchPolicy(Policy):
@DeveloperAPI @DeveloperAPI
def compute_actions( def compute_actions(
self, self,
obs_batch: Union[List[TensorType], TensorType], obs_batch: Union[List[TensorStructType], TensorStructType],
state_batches: Optional[List[TensorType]] = None, state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Union[List[TensorType], TensorType] = None, prev_action_batch: Union[List[TensorStructType],
prev_reward_batch: Union[List[TensorType], TensorType] = None, TensorStructType] = None,
prev_reward_batch: Union[List[TensorStructType],
TensorStructType] = None,
info_batch: Optional[Dict[str, list]] = None, info_batch: Optional[Dict[str, list]] = None,
episodes: Optional[List["MultiAgentEpisode"]] = None, episodes: Optional[List["MultiAgentEpisode"]] = None,
explore: Optional[bool] = None, explore: Optional[bool] = None,
timestep: Optional[int] = None, timestep: Optional[int] = None,
**kwargs) -> \ **kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
with torch.no_grad(): with torch.no_grad():
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32) seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
input_dict = self._lazy_tensor_dict( input_dict = self._lazy_tensor_dict({
SampleBatch({ SampleBatch.CUR_OBS: obs_batch
SampleBatch.CUR_OBS: np.asarray(obs_batch), })
}))
if prev_action_batch is not None: if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \ input_dict[SampleBatch.PREV_ACTIONS] = \
np.asarray(prev_action_batch) np.asarray(prev_action_batch)
@ -405,13 +406,13 @@ class TorchPolicy(Policy):
@DeveloperAPI @DeveloperAPI
def compute_log_likelihoods( def compute_log_likelihoods(
self, self,
actions: Union[List[TensorType], TensorType], actions: Union[List[TensorStructType], TensorStructType],
obs_batch: Union[List[TensorType], TensorType], obs_batch: Union[List[TensorStructType], TensorStructType],
state_batches: Optional[List[TensorType]] = None, state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType], prev_action_batch: Optional[Union[List[TensorStructType],
TensorType]] = None, TensorStructType]] = None,
prev_reward_batch: Optional[Union[List[TensorType], prev_reward_batch: Optional[Union[List[TensorStructType],
TensorType]] = None, TensorStructType]] = None,
actions_normalized: bool = True, actions_normalized: bool = True,
) -> TensorType: ) -> TensorType:

View file

@ -35,10 +35,11 @@ class Filter:
class NoFilter(Filter): class NoFilter(Filter):
is_concurrent = True is_concurrent = True
def __init__(self, *args):
pass
def __call__(self, x, update=True): def __call__(self, x, update=True):
# Process no further if already np.ndarray, dict, or tuple.
if isinstance(x, (np.ndarray, dict, tuple)):
return x
try: try:
return np.asarray(x) return np.asarray(x)
except Exception: except Exception:
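The early return keeps already-structured observations (ndarrays, dicts, tuples) out of np.asarray, which would otherwise produce object arrays or fail on ragged nesting. A standalone sketch of the same behavior:

import numpy as np

def no_filter(x):
    # Pass structured observations through untouched.
    if isinstance(x, (np.ndarray, dict, tuple)):
        return x
    return np.asarray(x)

obs = {"a": np.zeros(3), "b": (np.ones(2), 5)}
assert no_filter(obs) is obs                      # untouched
assert isinstance(no_filter([1, 2]), np.ndarray)  # plain lists still converted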

View file

@ -5,16 +5,10 @@ import sys
from typing import Any, Optional from typing import Any, Optional
from ray.rllib.utils.annotations import Deprecated from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType from ray.rllib.utils.typing import TensorShape, TensorType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Represents a generic tensor type.
TensorType = TensorType
# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = TensorStructType
def try_import_jax(error=False): def try_import_jax(error=False):
"""Tries importing JAX and FLAX and returns both modules (or Nones). """Tries importing JAX and FLAX and returns both modules (or Nones).
@ -199,36 +193,36 @@ def _torch_stubs():
return None, nn return None, nn
def get_variable(value, def get_variable(value: Any,
framework: str = "tf", framework: str = "tf",
trainable: bool = False, trainable: bool = False,
tf_name: str = "unnamed-variable", tf_name: str = "unnamed-variable",
torch_tensor: bool = False, torch_tensor: bool = False,
device: Optional[str] = None, device: Optional[str] = None,
shape: Optional[TensorShape] = None, shape: Optional[TensorShape] = None,
dtype: Optional[Any] = None): dtype: Optional[TensorType] = None) -> Any:
""" """
Args: Args:
value (any): The initial value to use. In the non-tf case, this will value: The initial value to use. In the non-tf case, this will
be returned as is. In the tf case, this could be a tf-Initializer be returned as is. In the tf case, this could be a tf-Initializer
object. object.
framework (str): One of "tf", "torch", or None. framework: One of "tf", "torch", or None.
trainable (bool): Whether the generated variable should be trainable: Whether the generated variable should be
trainable (tf)/require_grad (torch) or not (default: False). trainable (tf)/require_grad (torch) or not (default: False).
tf_name (str): For framework="tf": An optional name for the tf_name: For framework="tf": An optional name for the
tf.Variable. tf.Variable.
torch_tensor (bool): For framework="torch": Whether to actually create torch_tensor: For framework="torch": Whether to actually create
a torch.tensor, or just a python value (default). a torch.tensor, or just a python value (default).
device (Optional[torch.Device]): An optional torch device to use for device: An optional torch device to use for
the created torch tensor. the created torch tensor.
shape (Optional[TensorShape]): An optional shape to use iff `value` shape: An optional shape to use iff `value`
does not have any (e.g. if it's an initializer w/o explicit value). does not have any (e.g. if it's an initializer w/o explicit value).
dtype (Optional[TensorType]): An optional dtype to use iff `value` does dtype: An optional dtype to use iff `value` does
not have any (e.g. if it's an initializer w/o explicit value). not have any (e.g. if it's an initializer w/o explicit value).
This should always be a numpy dtype (e.g. np.float32, np.int64). This should always be a numpy dtype (e.g. np.float32, np.int64).
Returns: Returns:
any: A framework-specific variable (tf.Variable, torch.tensor, or A framework-specific variable (tf.Variable, torch.tensor, or
python primitive). python primitive).
""" """
if framework in ["tf2", "tf", "tfe"]: if framework in ["tf2", "tf", "tfe"]:

View file

@ -4,6 +4,7 @@ import numpy as np
import tree # pip install dm_tree import tree # pip install dm_tree
from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorStructType, TensorType from ray.rllib.utils.typing import TensorStructType, TensorType
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
@ -56,12 +57,26 @@ def get_gpu_devices():
return [d.name for d in devices if "GPU" in d.device_type] return [d.name for d in devices if "GPU" in d.device_type]
def get_placeholder(*, space=None, value=None, name=None, time_axis=False): def get_placeholder(*,
space=None,
value=None,
name=None,
time_axis=False,
flatten=True):
from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.catalog import ModelCatalog
if space is not None: if space is not None:
if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
return ModelCatalog.get_action_placeholder(space, None) if flatten:
return ModelCatalog.get_action_placeholder(space, None)
else:
return tree.map_structure_with_path(
lambda path, component: get_placeholder(
space=component,
name=name + "." + ".".join([str(p) for p in path]),
),
get_base_struct_from_space(space),
)
return tf1.placeholder( return tf1.placeholder(
shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
dtype=tf.float32 if space.dtype == np.float64 else space.dtype, dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
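How the per-component placeholder names fall out of the path walk above can be seen without TF. The helper below is a hypothetical stand-in for RLlib's get_base_struct_from_space, and the "obs." prefix is just an example name.

import tree  # pip install dm_tree
from gym.spaces import Box, Dict, Discrete, Tuple

def space_struct(space):
    # Hypothetical stand-in for get_base_struct_from_space().
    if isinstance(space, Dict):
        return {k: space_struct(s) for k, s in space.spaces.items()}
    if isinstance(space, Tuple):
        return tuple(space_struct(s) for s in space.spaces)
    return space

space = Dict({"a": Discrete(5), "b": Tuple((Box(-1.0, 1.0, (2,)), Discrete(3)))})
named = tree.map_structure_with_path(
    lambda path, comp: ("obs." + ".".join(str(p) for p in path), comp.shape),
    space_struct(space),
)
print(named)
# {'a': ('obs.a', ()), 'b': (('obs.b.0', (2,)), ('obs.b.1', ()))}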
@ -138,10 +153,13 @@ def zero_logps_from_actions(actions: TensorStructType) -> TensorType:
def one_hot(x, space): def one_hot(x, space):
if isinstance(space, Discrete): if isinstance(space, Discrete):
return tf.one_hot(x, space.n) return tf.one_hot(x, space.n, dtype=tf.float32)
elif isinstance(space, MultiDiscrete): elif isinstance(space, MultiDiscrete):
return tf.concat( return tf.concat(
[tf.one_hot(x[:, i], n) for i, n in enumerate(space.nvec)], [
tf.one_hot(x[:, i], n, dtype=tf.float32)
for i, n in enumerate(space.nvec)
],
axis=-1) axis=-1)
else: else:
raise ValueError("Unsupported space for `one_hot`: {}".format(space)) raise ValueError("Unsupported space for `one_hot`: {}".format(space))
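A numpy-only sketch of what one_hot now produces: float32 blocks, one per Discrete value or MultiDiscrete sub-space, concatenated along the last axis (shapes illustrative).

import numpy as np

def np_one_hot(x, n):
    # Row lookup into an identity matrix == one-hot encoding.
    return np.eye(n, dtype=np.float32)[x]

# Discrete(4), batch of 3:
print(np_one_hot(np.array([0, 2, 3]), 4).shape)  # (3, 4)

# MultiDiscrete([2, 3]), batch of 2 -> blocks of width 2 and 3, concatenated -> 5:
x = np.array([[1, 0], [0, 2]])
out = np.concatenate(
    [np_one_hot(x[:, i], n) for i, n in enumerate([2, 3])], axis=-1)
print(out.shape, out.dtype)  # (2, 5) float32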
@ -200,11 +218,12 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
def make_wrapper(fn): def make_wrapper(fn):
# Static-graph mode: Create placeholders and make a session call each # Static-graph mode: Create placeholders and make a session call each
# time the wrapped function is called. Return this session call's # time the wrapped function is called. Returns the output of this
# outputs. # session call.
if session_or_none is not None: if session_or_none is not None:
args_placeholders = [] args_placeholders = []
kwargs_placeholders = {} kwargs_placeholders = {}
symbolic_out = [None] symbolic_out = [None]
def call(*args, **kwargs): def call(*args, **kwargs):
@ -215,40 +234,42 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
else: else:
args_flat.append(a) args_flat.append(a)
args = args_flat args = args_flat
# We have not built any placeholders yet: Do this once here,
# then reuse the same placeholders each time we call this
# function again.
if symbolic_out[0] is None: if symbolic_out[0] is None:
with session_or_none.graph.as_default(): with session_or_none.graph.as_default():
for i, v in enumerate(args):
def _create_placeholders(path, value):
if dynamic_shape: if dynamic_shape:
if len(v.shape) > 0: if len(value.shape) > 0:
shape = (None, ) + v.shape[1:] shape = (None, ) + value.shape[1:]
else: else:
shape = () shape = ()
else: else:
shape = v.shape shape = value.shape
args_placeholders.append( return tf1.placeholder(
tf1.placeholder( dtype=value.dtype,
dtype=v.dtype, shape=shape,
shape=shape, name=".".join([str(p) for p in path]),
name="arg_{}".format(i))) )
for k, v in kwargs.items():
if dynamic_shape: placeholders = tree.map_structure_with_path(
if len(v.shape) > 0: _create_placeholders, args)
shape = (None, ) + v.shape[1:] for ph in tree.flatten(placeholders):
else: args_placeholders.append(ph)
shape = ()
else: placeholders = tree.map_structure_with_path(
shape = v.shape _create_placeholders, kwargs)
kwargs_placeholders[k] = \ for k, ph in placeholders.items():
tf1.placeholder( kwargs_placeholders[k] = ph
dtype=v.dtype,
shape=shape,
name="kwarg_{}".format(k))
symbolic_out[0] = fn(*args_placeholders, symbolic_out[0] = fn(*args_placeholders,
**kwargs_placeholders) **kwargs_placeholders)
feed_dict = dict(zip(args_placeholders, args)) feed_dict = dict(zip(args_placeholders, tree.flatten(args)))
feed_dict.update( tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v),
{kwargs_placeholders[k]: kwargs[k] kwargs_placeholders, kwargs)
for k in kwargs.keys()})
ret = session_or_none.run(symbolic_out[0], feed_dict) ret = session_or_none.run(symbolic_out[0], feed_dict)
return ret return ret