[RLlib] Layout of Trajectory View API (new class: Trajectory; not used yet). (#9269)

Sven Mika 2020-07-14 04:27:49 +02:00 committed by GitHub
parent 222635b63f
commit 03ab86567f
10 changed files with 485 additions and 43 deletions

View file

@@ -213,6 +213,13 @@ COMMON_CONFIG = {
# Use a background thread for sampling (slightly off-policy, usually not
# advisable to turn on unless your env specifically requires it).
"sample_async": False,
# Experimental flag to speed up sampling and use "trajectory views" as
# generic ModelV2 `input_dicts` that can be requested by the model to
# contain different information on the ongoing episode.
# NOTE: Only supported for PyTorch so far.
"_use_trajectory_view_api": False,
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter".
"observation_filter": "NoFilter",
# Whether to synchronize the statistics of remote filters.
@@ -1057,6 +1064,11 @@ class Trainer(Trainable):
@staticmethod
def _validate_config(config: PartialTrainerConfigDict):
if config.get("_use_trajectory_view_api") and \
config.get("framework") != "torch":
raise ValueError(
"`_use_trajectory_view_api` only supported for PyTorch so "
"far!")
if "policy_graphs" in config["multiagent"]: if "policy_graphs" in config["multiagent"]:
deprecation_warning("policy_graphs", "policies") deprecation_warning("policy_graphs", "policies")
# Backwards compatibility. # Backwards compatibility.
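
For illustration, a minimal config sketch for turning the new flag on (the flag is experimental and, per the validation above, only accepted together with the torch framework):

config = {
    "framework": "torch",  # any other framework fails Trainer._validate_config()
    "_use_trajectory_view_api": True,
}

# Combining the flag with a non-torch framework raises the error added above:
# ValueError: `_use_trajectory_view_api` only supported for PyTorch so far!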

View file

@@ -496,7 +496,9 @@ class RolloutWorker(ParallelIteratorWorker):
blackhole_outputs="simulation" in input_evaluation,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=policy_config.get(
"_use_trajectory_view_api", False))
# Start the Sampler thread.
self.sampler.start()
else:
@@ -516,7 +518,9 @@ class RolloutWorker(ParallelIteratorWorker):
clip_actions=clip_actions,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=policy_config.get(
"_use_trajectory_view_api", False))
self.input_reader: InputReader = input_creator(self.io_context)
self.output_writer: OutputWriter = output_creator(self.io_context)
@@ -561,7 +565,8 @@ class RolloutWorker(ParallelIteratorWorker):
batch = self.input_reader.next()
steps_so_far += batch.count
batches.append(batch)
batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
batches[0]
self.callbacks.on_sample_end(worker=self, samples=batch)

View file

@@ -5,8 +5,8 @@ import numpy as np
import queue
import threading
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, \
TYPE_CHECKING, Union
from ray.util.debug import log_once
from ray.rllib.evaluation.episode import MultiAgentEpisode
@@ -113,7 +113,8 @@ class SyncSampler(SamplerInput):
clip_actions: bool = True,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
_use_trajectory_view_api: bool = False):
"""Initializes a SyncSampler object.
Args:
@@ -150,6 +151,9 @@ class SyncSampler(SamplerInput):
observation_fn (Optional[ObservationFunction]): Optional
multi-agent observation func to use for preprocessing
observations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
"""
self.base_env = BaseEnv.to_base_env(env)
@@ -167,7 +171,8 @@ class SyncSampler(SamplerInput):
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack_multiple_episodes_in_batch, callbacks, tf_sess,
self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
_use_trajectory_view_api)
self.metrics_queue = queue.Queue()
@override(SamplerInput)
@@ -227,7 +232,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
blackhole_outputs: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
_use_trajectory_view_api: bool = False):
"""Initializes a AsyncSampler object. """Initializes a AsyncSampler object.
Args:
@@ -266,6 +272,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
observation_fn (Optional[ObservationFunction]): Optional
multi-agent observation func to use for preprocessing
observations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
"""
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
@@ -294,6 +303,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.perf_stats = _PerfStats()
self.shutdown = False
self.observation_fn = observation_fn
self._use_trajectory_view_api = _use_trajectory_view_api
@override(threading.Thread)
def run(self):
@@ -317,7 +327,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack_multiple_episodes_in_batch,
self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon,
self.no_done_at_end, self.observation_fn,
self._use_trajectory_view_api)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -362,24 +373,34 @@ class AsyncSampler(threading.Thread, SamplerInput):
return extra
def _env_runner(worker: "RolloutWorker",
base_env: BaseEnv,
extra_batch_callback: Callable[[SampleBatchType], None],
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
rollout_fragment_length: int,
horizon: int,
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
clip_actions: bool,
pack_multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
tf_sess: Optional["tf.Session"],
perf_stats: _PerfStats,
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
_use_trajectory_view_api: bool = False
) -> Iterable[SampleBatchType]:
"""This implements the common experience collection logic.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (Dict[PolicyID, Policy]): Map of policy ids to Policy
instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
agent is then "bound" to the returned policy for the episode.
@@ -406,6 +427,9 @@ def _env_runner(
and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent
observation func to use for preprocessing observations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@@ -508,7 +532,8 @@ def _env_runner(
callbacks=callbacks,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=_use_trajectory_view_api)
perf_stats.processing_time += time.time() - t1
for o in outputs:
yield o
@@ -520,7 +545,8 @@ def _env_runner(
to_eval=to_eval,
policies=policies,
active_episodes=active_episodes,
tf_sess=tf_sess,
_use_trajectory_view_api=_use_trajectory_view_api)
perf_stats.inference_time += time.time() - t2
# Process results and update episode state.
@@ -533,7 +559,8 @@ def _env_runner(
active_envs=active_envs,
off_policy_actions=off_policy_actions,
policies=policies,
clip_actions=clip_actions,
_use_trajectory_view_api=_use_trajectory_view_api)
perf_stats.processing_time += time.time() - t3
# Return computed actions to ready envs. We also send to envs that have
@@ -556,7 +583,8 @@ def _process_observations(
obs_filters: Dict[PolicyID, Filter], rollout_fragment_length: int,
pack_multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks",
soft_horizon: bool, no_done_at_end: bool,
observation_fn: "ObservationFunction",
_use_trajectory_view_api: bool = False
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
RolloutMetrics, SampleBatchType]]]:
"""Record new data from the environment and prepare for policy evaluation.
@@ -595,6 +623,9 @@ def _process_observations(
and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent
observation func to use for preprocessing observations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
Returns:
Tuple:
@@ -811,18 +842,24 @@ def _do_policy_eval(
to_eval: Dict[PolicyID, List[PolicyEvalData]],
policies: Dict[PolicyID, Policy],
active_episodes: Dict[str, MultiAgentEpisode],
tf_sess=None,
_use_trajectory_view_api=False
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action.
Args:
to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
IDs to lists of PolicyEvalData objects (items in these lists will
be the batch's items for the model forward pass).
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
obj.
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
batching TF policy evaluations.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` procedure to collect samples.
Default: False.
Returns:
eval_results: dict of policy to compute_action() outputs.
@@ -888,11 +925,17 @@ def _do_policy_eval(
def _process_policy_eval_results(
*,
to_eval: Dict[PolicyID, List[PolicyEvalData]],
eval_results: Dict[PolicyID, Tuple[
TensorStructType, StateBatch, dict]],
active_episodes: Dict[str, MultiAgentEpisode],
active_envs: Set[int],
off_policy_actions: MultiEnvDict,
policies: Dict[PolicyID, Policy],
clip_actions: bool,
_use_trajectory_view_api: bool = False
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
"""Process the output of policy neural network evaluation.
Records policy evaluation results into the given episode objects and
@@ -911,6 +954,9 @@ def _process_policy_eval_results(
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy.
clip_actions (bool): Whether to clip actions to the action space's
bounds.
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
Returns:
actions_to_send: Nested dict of env id -> agent id -> agent replies.

View file

@@ -0,0 +1,70 @@
from gym.spaces import Box, Discrete
import numpy as np
import unittest
from ray.rllib.evaluation.trajectory import Trajectory
class TestTrajectories(unittest.TestCase):
"""Tests Trajectory classes."""
def test_trajectory(self):
"""Tests the Trajectory class."""
buffer_size = 5
# Small trajectory object for testing purposes.
trajectory = Trajectory(buffer_size=buffer_size)
self.assertEqual(trajectory.cursor, 0)
self.assertEqual(trajectory.timestep, 0)
self.assertEqual(trajectory.sample_batch_offset, 0)
assert not trajectory.buffers
observation_space = Box(-1.0, 1.0, shape=(3, ))
action_space = Discrete(2)
trajectory.add_init_obs(
env_id=0,
agent_id="agent",
policy_id="policy",
init_obs=observation_space.sample())
self.assertEqual(trajectory.cursor, 0)
self.assertEqual(trajectory.initial_obs.shape, observation_space.shape)
# Fill up the buffer and make it extend if it hits the limit.
cur_buffer_size = buffer_size
for i in range(buffer_size + 1):
trajectory.add_action_reward_next_obs(
env_id=0,
agent_id="agent",
policy_id="policy",
values=dict(
t=i,
actions=action_space.sample(),
rewards=1.0,
dones=i == buffer_size,
new_obs=observation_space.sample(),
action_logp=-0.5,
action_dist_inputs=np.array([[0.5, 0.5]]),
))
self.assertEqual(trajectory.cursor, i + 1)
self.assertEqual(trajectory.timestep, i + 1)
self.assertEqual(trajectory.sample_batch_offset, 0)
if i == buffer_size - 1:
cur_buffer_size *= 2
self.assertEqual(
len(trajectory.buffers["new_obs"]), cur_buffer_size)
self.assertEqual(
len(trajectory.buffers["rewards"]), cur_buffer_size)
# Create a SampleBatch from the Trajectory and reset it.
batch = trajectory.get_sample_batch_and_reset()
self.assertEqual(batch.count, buffer_size + 1)
# Make sure, Trajectory was reset properly.
self.assertEqual(trajectory.cursor, buffer_size + 1)
self.assertEqual(trajectory.timestep, 0)
self.assertEqual(trajectory.sample_batch_offset, buffer_size + 1)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@@ -0,0 +1,267 @@
import logging
import numpy as np
from typing import Dict, Optional
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.types import AgentID, EnvID, PolicyID, TensorType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
def to_float_array(v):
if torch and isinstance(v[0], torch.Tensor):
arr = torch.stack(v).numpy() # np.array([s.numpy() for s in v])
else:
arr = np.array(v)
if arr.dtype == np.float64:
return arr.astype(np.float32) # save some memory
return arr
class Trajectory:
"""A trajectory of a (single) agent throughout one episode.
Note: This is an experimental class only used when
`config._use_trajectory_view_api` = True.
Collects all data produced by the environment during stepping of the agent
as well as all model outputs associated with the agent's Policy into
pre-allocated buffers of n timesteps capacity (`self.buffer_size`).
NOTE: A Trajectory object may contain remainders of a previous trajectory,
however, these are only kept for avoiding memory re-allocations. A
convenience cursor and offset-pointers allow for only "viewing" the
currently ongoing trajectory.
Memory re-allocation into larger buffers (`self.buffer_size *= 2`) only
happens if unavoidable (in case the buffer is full AND the currently
ongoing trajectory (episode) takes more than half of the buffer). In all
other cases, the same buffer is used for succeeding episodes/trajectories
(even for different agents).
"""
# Disambiguate unrolls within a single episode.
_next_unroll_id = 0
def __init__(self, buffer_size: Optional[int] = None):
"""Initializes a Trajectory object.
Args:
buffer_size (Optional[int]): The max number of timesteps to
fit into one buffer column. When re-allocating, the buffer size is doubled.
"""
# The current occupant (agent X in env Y using policy Z) of our
# buffers.
self.env_id: EnvID = None
self.agent_id: AgentID = None
self.policy_id: PolicyID = None
# Determine the size of the initial buffers.
self.buffer_size = buffer_size or 1000
# The actual buffer holding dict (by column name (str) ->
# numpy/torch/tf tensors).
self.buffers = {}
# Holds the initial observation data.
self.initial_obs = None
# Cursor into the preallocated buffers. This is where all new data
# gets inserted.
self.cursor: int = 0
# The offset inside our buffer where the current trajectory starts.
self.trajectory_offset: int = 0
# The offset inside our buffer, from where to build the next
# SampleBatch.
self.sample_batch_offset: int = 0
@property
def timestep(self) -> int:
"""The timestep in the (currently ongoing) trajectory/episode."""
return self.cursor - self.trajectory_offset
def add_init_obs(self,
env_id: EnvID,
agent_id: AgentID,
policy_id: PolicyID,
init_obs: TensorType) -> None:
"""Adds a single initial observation (after env.reset()) to the buffer.
Stores it in self.initial_obs.
Args:
env_id (EnvID): Unique id for the episode we are adding the initial
observation for.
agent_id (AgentID): Unique id for the agent we are adding the
initial observation for.
policy_id (PolicyID): Unique id for policy controlling the agent.
init_obs (TensorType): Initial observation (after env.reset()).
"""
self.env_id = env_id
self.agent_id = agent_id
self.policy_id = policy_id
self.initial_obs = init_obs
def add_action_reward_next_obs(self,
env_id: EnvID,
agent_id: AgentID,
policy_id: PolicyID,
values: Dict[str, TensorType]) -> None:
"""Add the given dictionary (row) of values to this batch.
Args:
env_id (EnvID): Unique id for the episode we are adding the initial
observation for.
agent_id (AgentID): Unique id for the agent we are adding the
initial observation for.
policy_id (PolicyID): Unique id for policy controlling the agent.
values (Dict[str, TensorType]): Data dict (interpreted as a single
row) to be added to buffer. Must contain keys:
SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
"""
assert self.initial_obs is not None
assert (SampleBatch.ACTIONS in values and SampleBatch.REWARDS in values
and SampleBatch.NEXT_OBS in values)
assert env_id == self.env_id
assert agent_id == self.agent_id
assert policy_id == self.policy_id
# Only obs exists so far in buffers:
# Initialize all other columns.
if len(self.buffers) == 0:
self._build_buffers(single_row=values)
for k, v in values.items():
self.buffers[k][self.cursor] = v
self.cursor += 1
# Extend (re-alloc) buffers if full.
if self.cursor == self.buffer_size:
self._extend_buffers(values)
def get_sample_batch_and_reset(self) -> SampleBatch:
"""Returns a SampleBatch carrying all previously added data.
If a reset happens and the trajectory is not done yet, we'll keep the
entire ongoing trajectory in memory for Model view requirement purposes
and only actually free the data, once the episode ends.
Returns:
SampleBatch: The SampleBatch containing this agent's data for the
entire trajectory (so far). The trajectory may not be
terminated yet. This SampleBatch object will contain a
`_last_obs` property, which contains the last observation for
this agent. This should be used by postprocessing functions
instead of the SampleBatch.NEXT_OBS field, which is deprecated.
"""
assert SampleBatch.UNROLL_ID not in self.buffers
# Convert all our data to numpy arrays, compress float64 to float32,
# and add the last observation data as well (always one more obs than
# all other columns due to the additional obs returned by Env.reset()).
data = {}
for k, v in self.buffers.items():
data[k] = to_float_array(
v[self.sample_batch_offset:self.cursor])
# Add unroll ID column to batch if non-existent.
uid = Trajectory._next_unroll_id
data[SampleBatch.UNROLL_ID] = np.repeat(
uid, self.cursor - self.sample_batch_offset)
inputs = {uid: {}}
if "t" in self.buffers:
if self.buffers["t"][self.sample_batch_offset] > 0:
for k in self.buffers.keys():
inputs[uid][k] = \
self.buffers[k][self.sample_batch_offset - 1]
else:
inputs[uid][SampleBatch.NEXT_OBS] = self.initial_obs
else:
inputs[uid][SampleBatch.NEXT_OBS] = self.initial_obs
Trajectory._next_unroll_id += 1
batch = SampleBatch(data, _initial_inputs=inputs)
# If done at end -> We can reset our buffers entirely.
if self.buffers[SampleBatch.DONES][self.cursor - 1]:
# Set self.timestep to 0 -> new trajectory w/o re-alloc (not yet,
# only ever re-alloc when necessary).
self.trajectory_offset = self.sample_batch_offset = self.cursor
# No done at end -> leave trajectory_offset as is (trajectory is still
# ongoing), but move the sample_batch offset to cursor.
else:
self.sample_batch_offset = self.cursor
return batch
def _build_buffers(self, single_row):
"""Creates zero-filled pre-allocated numpy buffers for data collection.
Except for the obs-column, which should already be initialized (done
on call to `self.add_initial_observation()`).
Args:
single_row (Dict[str,np.ndarray]): Dict of column names (keys) and
sample numpy data (values). Note: Only one of `single_data` or
`data_batch` must be provided.
"""
for col, data in single_row.items():
# Skip already initialized ones, e.g. 'obs' if used with
# add_init_obs().
if col in self.buffers:
continue
self.buffers[col] = [None] * self.buffer_size
def _extend_buffers(self, single_row):
"""Extends the buffers (depending on trajectory state/length).
- Extend all buffer lists (x2) if trajectory starts at 0 (trajectory is
longer than current self.buffer_size).
- Trajectory starts in first half of buffer: Create new buffer lists
(2x buffer sizes) and move Trajectory to beginning of new buffer.
- Trajectory starts in last half of buffer: Leave buffer as is, but
move trajectory to very front (cursor=0).
Args:
single_row (dict): Data dict example to use in case we have to
re-build buffer.
"""
traj_length = self.cursor - self.trajectory_offset
# Trajectory starts at 0 (meaning episodes are longer than current
# `self.buffer_size` -> Simply do a resize (enlarge) on each column
# in the buffer.
if self.trajectory_offset == 0:
# Double actual horizon.
for col, data in self.buffers.items():
self.buffers[col].extend([None] * self.buffer_size)
self.buffer_size *= 2
# Trajectory starts in first half of the buffer -> Reallocate a new
# buffer and copy the currently ongoing trajectory into the new buffer.
elif self.trajectory_offset < self.buffer_size / 2:
# Double actual horizon.
self.buffer_size *= 2
# Store currently ongoing trajectory and build a new buffer.
old_buffers = self.buffers
self.buffers = {}
self._build_buffers(single_row)
# Copy the still ongoing trajectory into the new buffer.
for col, data in old_buffers.items():
self.buffers[col][:traj_length] = data[self.trajectory_offset:
self.cursor]
# Do an efficient memory swap: Move current trajectory simply to
# the beginning of the buffer (no reallocation/None-padding necessary).
else:
for col, data in self.buffers.items():
self.buffers[col][:traj_length] = self.buffers[col][
self.trajectory_offset:self.cursor]
# Set all pointers to their correct new values.
self.sample_batch_offset = (
self.sample_batch_offset - self.trajectory_offset)
self.trajectory_offset = 0
self.cursor = traj_length
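
For illustration, a minimal usage sketch of the new class, mirroring the unit test above (names such as "agent_0"/"default_policy" are arbitrary; the buffer doubles itself once it fills up):

import numpy as np
from ray.rllib.evaluation.trajectory import Trajectory

traj = Trajectory(buffer_size=4)
traj.add_init_obs(env_id=0, agent_id="agent_0", policy_id="default_policy",
                  init_obs=np.zeros(3, dtype=np.float32))
for t in range(6):  # deliberately exceed buffer_size to trigger a re-alloc
    traj.add_action_reward_next_obs(
        env_id=0, agent_id="agent_0", policy_id="default_policy",
        values=dict(
            t=t,
            actions=1,
            rewards=1.0,
            dones=(t == 5),
            new_obs=np.ones(3, dtype=np.float32),
        ))
batch = traj.get_sample_batch_and_reset()  # SampleBatch with 6 rows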

View file

@@ -266,6 +266,10 @@ class ModelV2:
# Single requirement: Pass current obs as input.
return {
SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
SampleBatch.PREV_ACTIONS:
ViewRequirement(SampleBatch.ACTIONS, timesteps=-1),
SampleBatch.PREV_REWARDS:
ViewRequirement(SampleBatch.REWARDS, timesteps=-1),
}
def import_from_h5(self, h5_file):
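
Conceptually, the two added requirements ask the sampler for the action and reward from the previous timestep (shift of -1) alongside the current observation. A rough, non-RLlib illustration of that shift (the zero-padding at t=0 is an assumption here, not taken from this diff):

import numpy as np

actions = np.array([2, 0, 1, 1])
rewards = np.array([0.5, 1.0, -1.0, 0.0])
# ViewRequirement(SampleBatch.ACTIONS, timesteps=-1) / (REWARDS, timesteps=-1):
prev_actions = np.concatenate([[0], actions[:-1]])
prev_rewards = np.concatenate([[0.0], rewards[:-1]])
# At t=2, the model's input row would contain obs[2], prev_actions[2] == 0,
# and prev_rewards[2] == 1.0.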
@@ -338,7 +342,7 @@ class NullContextManager:
@DeveloperAPI
def flatten(obs, framework):
"""Flatten the given tensor."""
if framework in ["tf2", "tf", "tfe"]:
return tf1.keras.layers.Flatten()(obs)
elif framework == "torch":
assert torch is not None

View file

@@ -217,7 +217,7 @@ class Policy(metaclass=ABCMeta):
def compute_actions_from_trajectories(
self,
trajectories: List["Trajectory"],
other_trajectories: Optional[Dict[AgentID, "Trajectory"]] = None,
explore: bool = None,
timestep: Optional[int] = None,
**kwargs) -> \
@@ -226,14 +226,14 @@
Note: This is an experimental API method.
Only used so far by the Sampler iff `_use_trajectory_view_api=True`
(also only supported for torch).
Args:
trajectories (List[Trajectory]): A List of Trajectory data used
to create a view for the Model forward call.
other_trajectories (Optional[Dict[AgentID, Trajectory]]): Optional
dict mapping AgentIDs to Trajectory objects.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
timestep (Optional[int]): The current (sampling) time step.

View file

@@ -58,6 +58,8 @@ class SampleBatch:
def __init__(self, *args, **kwargs):
"""Constructs a sample batch (same params as dict constructor)."""
self._initial_inputs = kwargs.pop("_initial_inputs", {})
self.data = dict(*args, **kwargs)
lengths = []
for k, v in self.data.copy().items():
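
A small sketch of the new kwarg: it is popped before the underlying dict is built, so it never becomes a data column. `Trajectory.get_sample_batch_and_reset()` above passes it as `_initial_inputs=inputs`; the example values below are made up:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch(
    {"obs": np.zeros((3, 4)), "rewards": np.ones(3)},
    _initial_inputs={0: {"new_obs": np.zeros(4)}})  # keyed by unroll id
assert "_initial_inputs" not in batch.data
assert batch._initial_inputs[0]["new_obs"].shape == (4,)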

View file

@@ -168,7 +168,7 @@ class TorchPolicy(Policy):
def compute_actions_from_trajectories(
self,
trajectories: List["Trajectory"],
other_trajectories: Optional[Dict[AgentID, "Trajectory"]] = None,
explore: bool = None,
timestep: Optional[int] = None,
**kwargs) -> \

View file

@@ -1,4 +1,5 @@
import numpy as np
import tree
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -247,3 +248,38 @@ def lstm(x,
unrolled_outputs[:, t, :] = h_states
return unrolled_outputs, (c_states, h_states)
# TODO: (sven) this will replace `TorchPolicy._convert_to_non_torch_tensor()`.
def convert_to_numpy(x, reduce_floats=False):
"""Converts values in `stats` to non-Tensor numpy or python types.
Args:
stats (any): Any (possibly nested) struct, the values in which will be
converted and returned as a new struct with all torch/tf tensors
being converted to numpy types.
reduce_floats (bool): Whether to reduce all float64 data into float32
automatically.
Returns:
Any: A new struct with the same structure as `stats`, but with all
values converted to numpy arrays (on CPU).
"""
# The mapping function used to numpyize torch/tf Tensors (and move them
# to the CPU beforehand).
def mapping(item):
if torch and isinstance(item, torch.Tensor):
ret = item.cpu().item() if len(item.size()) == 0 else \
item.cpu().detach().numpy()
elif tf and isinstance(item, tf.Tensor):
assert tf.executing_eagerly()
ret = item.cpu().numpy()
else:
ret = item
if reduce_floats and isinstance(ret, np.ndarray) and \
ret.dtype == np.float64:
ret = ret.astype(np.float32)
return ret
return tree.map_structure(mapping, x)
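
A short usage sketch, assuming `convert_to_numpy` above is in scope: nested torch tensors become numpy arrays (or python scalars for 0-d tensors) on the CPU, and float64 arrays are downcast when `reduce_floats=True`:

import numpy as np
import torch

stats = {
    "loss": torch.tensor(0.25),                # 0-d tensor -> python float
    "grads": [torch.ones(3), torch.zeros(2)],  # -> np.ndarray (float32)
}
out = convert_to_numpy(stats, reduce_floats=True)
assert isinstance(out["loss"], float)
assert isinstance(out["grads"][0], np.ndarray)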