mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)

[RLlib] Layout of Trajectory View API (new class: Trajectory; not used yet). (#9269)

This commit is contained in: parent 222635b63f, commit 03ab86567f

10 changed files with 485 additions and 43 deletions
@@ -213,6 +213,13 @@ COMMON_CONFIG = {
     # Use a background thread for sampling (slightly off-policy, usually not
     # advisable to turn on unless your env specifically requires it).
     "sample_async": False,
+
+    # Experimental flag to speed up sampling and use "trajectory views" as
+    # generic ModelV2 `input_dicts` that can be requested by the model to
+    # contain different information on the ongoing episode.
+    # NOTE: Only supported for PyTorch so far.
+    "_use_trajectory_view_api": False,
+
     # Element-wise observation filter, either "NoFilter" or "MeanStdFilter".
     "observation_filter": "NoFilter",
     # Whether to synchronize the statistics of remote filters.
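Note on the hunk above: it only introduces the config key; per the commit title, the new Trajectory class is not wired into sampling yet, so switching the flag on changes nothing beyond passing validation. A minimal sketch of enabling it (PPOTrainer and "CartPole-v0" are illustrative assumptions, not part of this diff; framework must be "torch", see the _validate_config hunk below):

# Sketch only: enable the experimental flag added in the hunk above.
# PPOTrainer / "CartPole-v0" are assumed here for illustration.
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
config = {
    "framework": "torch",              # required: the flag is torch-only so far
    "_use_trajectory_view_api": True,  # new key, defaults to False
}
trainer = PPOTrainer(config=config, env="CartPole-v0")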
@@ -1057,6 +1064,11 @@ class Trainer(Trainable):

     @staticmethod
     def _validate_config(config: PartialTrainerConfigDict):
+        if config.get("_use_trajectory_view_api") and \
+                config.get("framework") != "torch":
+            raise ValueError(
+                "`_use_trajectory_view_api` only supported for PyTorch so "
+                "far!")
         if "policy_graphs" in config["multiagent"]:
             deprecation_warning("policy_graphs", "policies")
             # Backwards compatibility.
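The added check fails fast when the flag is combined with a non-torch framework. A small sketch of that behavior (it calls the private static method directly, for illustration only):

# Sketch: the new validation path added above.
from ray.rllib.agents.trainer import Trainer

try:
    Trainer._validate_config({
        "_use_trajectory_view_api": True,
        "framework": "tf",
    })
except ValueError as e:
    print(e)  # `_use_trajectory_view_api` only supported for PyTorch so far!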
@@ -496,7 +496,9 @@ class RolloutWorker(ParallelIteratorWorker):
                 blackhole_outputs="simulation" in input_evaluation,
                 soft_horizon=soft_horizon,
                 no_done_at_end=no_done_at_end,
-                observation_fn=observation_fn)
+                observation_fn=observation_fn,
+                _use_trajectory_view_api=policy_config.get(
+                    "_use_trajectory_view_api", False))
             # Start the Sampler thread.
             self.sampler.start()
         else:
@@ -516,7 +518,9 @@ class RolloutWorker(ParallelIteratorWorker):
                 clip_actions=clip_actions,
                 soft_horizon=soft_horizon,
                 no_done_at_end=no_done_at_end,
-                observation_fn=observation_fn)
+                observation_fn=observation_fn,
+                _use_trajectory_view_api=policy_config.get(
+                    "_use_trajectory_view_api", False))

         self.input_reader: InputReader = input_creator(self.io_context)
         self.output_writer: OutputWriter = output_creator(self.io_context)
@@ -561,7 +565,8 @@ class RolloutWorker(ParallelIteratorWorker):
                 batch = self.input_reader.next()
                 steps_so_far += batch.count
                 batches.append(batch)
-            batch = batches[0].concat_samples(batches)
+            batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
+                batches[0]

         self.callbacks.on_sample_end(worker=self, samples=batch)

@@ -5,8 +5,8 @@ import numpy as np
 import queue
 import threading
 import time
-from typing import List, Dict, Callable, Set, Tuple, Any, Iterable, Union, \
-    TYPE_CHECKING
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, \
+    TYPE_CHECKING, Union

 from ray.util.debug import log_once
 from ray.rllib.evaluation.episode import MultiAgentEpisode
@@ -113,7 +113,8 @@ class SyncSampler(SamplerInput):
                  clip_actions: bool = True,
                  soft_horizon: bool = False,
                  no_done_at_end: bool = False,
-                 observation_fn: "ObservationFunction" = None):
+                 observation_fn: "ObservationFunction" = None,
+                 _use_trajectory_view_api: bool = False):
         """Initializes a SyncSampler object.

         Args:
@@ -150,6 +151,9 @@ class SyncSampler(SamplerInput):
             observation_fn (Optional[ObservationFunction]): Optional
                 multi-agent observation func to use for preprocessing
                 observations.
+            _use_trajectory_view_api (bool): Whether to use the (experimental)
+                `_use_trajectory_view_api` to make generic trajectory views
+                available to Models. Default: False.
         """

         self.base_env = BaseEnv.to_base_env(env)
@@ -167,7 +171,8 @@ class SyncSampler(SamplerInput):
             self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
             self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
             pack_multiple_episodes_in_batch, callbacks, tf_sess,
-            self.perf_stats, soft_horizon, no_done_at_end, observation_fn)
+            self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
+            _use_trajectory_view_api)
         self.metrics_queue = queue.Queue()

     @override(SamplerInput)
@@ -227,7 +232,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
                  blackhole_outputs: bool = False,
                  soft_horizon: bool = False,
                  no_done_at_end: bool = False,
-                 observation_fn: "ObservationFunction" = None):
+                 observation_fn: "ObservationFunction" = None,
+                 _use_trajectory_view_api: bool = False):
         """Initializes a AsyncSampler object.

         Args:
@@ -266,6 +272,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
             observation_fn (Optional[ObservationFunction]): Optional
                 multi-agent observation func to use for preprocessing
                 observations.
+            _use_trajectory_view_api (bool): Whether to use the (experimental)
+                `_use_trajectory_view_api` to make generic trajectory views
+                available to Models. Default: False.
         """
         for _, f in obs_filters.items():
             assert getattr(f, "is_concurrent", False), \
@@ -294,6 +303,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
         self.perf_stats = _PerfStats()
         self.shutdown = False
         self.observation_fn = observation_fn
+        self._use_trajectory_view_api = _use_trajectory_view_api

     @override(threading.Thread)
     def run(self):
@@ -317,7 +327,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.preprocessors, self.obs_filters, self.clip_rewards,
             self.clip_actions, self.pack_multiple_episodes_in_batch,
             self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon,
-            self.no_done_at_end, self.observation_fn)
+            self.no_done_at_end, self.observation_fn,
+            self._use_trajectory_view_api)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -362,24 +373,34 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra


-def _env_runner(
-        worker: "RolloutWorker", base_env: BaseEnv,
-        extra_batch_callback: Callable[[SampleBatchType], None], policies,
-        policy_mapping_fn: Callable[[AgentID], PolicyID],
-        rollout_fragment_length: int, horizon: int,
-        preprocessors: Dict[PolicyID, Preprocessor],
-        obs_filters: Dict[PolicyID, Filter], clip_rewards: bool,
-        clip_actions: bool, pack_multiple_episodes_in_batch: bool,
-        callbacks: "DefaultCallbacks", tf_sess, perf_stats: _PerfStats,
-        soft_horizon: bool, no_done_at_end: bool,
-        observation_fn: "ObservationFunction") -> Iterable[SampleBatchType]:
+def _env_runner(worker: "RolloutWorker",
+                base_env: BaseEnv,
+                extra_batch_callback: Callable[[SampleBatchType], None],
+                policies: Dict[PolicyID, Policy],
+                policy_mapping_fn: Callable[[AgentID], PolicyID],
+                rollout_fragment_length: int,
+                horizon: int,
+                preprocessors: Dict[PolicyID, Preprocessor],
+                obs_filters: Dict[PolicyID, Filter],
+                clip_rewards: bool,
+                clip_actions: bool,
+                pack_multiple_episodes_in_batch: bool,
+                callbacks: "DefaultCallbacks",
+                tf_sess: Optional["tf.Session"],
+                perf_stats: _PerfStats,
+                soft_horizon: bool,
+                no_done_at_end: bool,
+                observation_fn: "ObservationFunction",
+                _use_trajectory_view_api: bool = False
+                ) -> Iterable[SampleBatchType]:
     """This implements the common experience collection logic.

     Args:
         worker (RolloutWorker): Reference to the current rollout worker.
         base_env (BaseEnv): Env implementing BaseEnv.
         extra_batch_callback (fn): function to send extra batch data to.
-        policies (dict): Map of policy ids to Policy instances.
+        policies (Dict[PolicyID, Policy]): Map of policy ids to Policy
+            instances.
         policy_mapping_fn (func): Function that maps agent ids to policy ids.
             This is called when an agent first enters the environment. The
             agent is then "bound" to the returned policy for the episode.
@@ -406,6 +427,9 @@ def _env_runner(
             and instead record done=False.
         observation_fn (ObservationFunction): Optional multi-agent
             observation func to use for preprocessing observations.
+        _use_trajectory_view_api (bool): Whether to use the (experimental)
+            `_use_trajectory_view_api` to make generic trajectory views
+            available to Models. Default: False.

     Yields:
         rollout (SampleBatch): Object containing state, action, reward,
@@ -508,7 +532,8 @@ def _env_runner(
             callbacks=callbacks,
             soft_horizon=soft_horizon,
             no_done_at_end=no_done_at_end,
-            observation_fn=observation_fn)
+            observation_fn=observation_fn,
+            _use_trajectory_view_api=_use_trajectory_view_api)
         perf_stats.processing_time += time.time() - t1
         for o in outputs:
             yield o
@@ -520,7 +545,8 @@ def _env_runner(
             to_eval=to_eval,
             policies=policies,
             active_episodes=active_episodes,
-            tf_sess=tf_sess)
+            tf_sess=tf_sess,
+            _use_trajectory_view_api=_use_trajectory_view_api)
         perf_stats.inference_time += time.time() - t2

         # Process results and update episode state.
@@ -533,7 +559,8 @@ def _env_runner(
             active_envs=active_envs,
             off_policy_actions=off_policy_actions,
             policies=policies,
-            clip_actions=clip_actions)
+            clip_actions=clip_actions,
+            _use_trajectory_view_api=_use_trajectory_view_api)
         perf_stats.processing_time += time.time() - t3

         # Return computed actions to ready envs. We also send to envs that have
@@ -556,7 +583,8 @@ def _process_observations(
         obs_filters: Dict[PolicyID, Filter], rollout_fragment_length: int,
         pack_multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks",
         soft_horizon: bool, no_done_at_end: bool,
-        observation_fn: "ObservationFunction"
+        observation_fn: "ObservationFunction",
+        _use_trajectory_view_api: bool = False
 ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
         RolloutMetrics, SampleBatchType]]]:
     """Record new data from the environment and prepare for policy evaluation.
@@ -595,6 +623,9 @@ def _process_observations(
             and instead record done=False.
         observation_fn (ObservationFunction): Optional multi-agent
             observation func to use for preprocessing observations.
+        _use_trajectory_view_api (bool): Whether to use the (experimental)
+            `_use_trajectory_view_api` to make generic trajectory views
+            available to Models. Default: False.

     Returns:
         Tuple:
@@ -811,18 +842,24 @@ def _do_policy_eval(
         to_eval: Dict[PolicyID, List[PolicyEvalData]],
         policies: Dict[PolicyID, Policy],
         active_episodes: Dict[str, MultiAgentEpisode],
-        tf_sess=None
+        tf_sess=None,
+        _use_trajectory_view_api=False
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
     """Call compute_actions on collected episode/model data to get next action.

     Args:
+        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
+            IDs to lists of PolicyEvalData objects (items in these lists will
+            be the batch's items for the model forward pass).
+        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
+            obj.
+        active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
+            episode ID to currently ongoing MultiAgentEpisode object.
         tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
             batching TF policy evaluations.
-        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs
-            to lists of PolicyEvalData objects.
-        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy.
-        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
-            episode ID to currently ongoing MultiAgentEpisode object.
+        _use_trajectory_view_api (bool): Whether to use the (experimental)
+            `_use_trajectory_view_api` procedure to collect samples.
+            Default: False.

     Returns:
         eval_results: dict of policy to compute_action() outputs.
@@ -888,11 +925,17 @@ def _do_policy_eval(


 def _process_policy_eval_results(
-        *, to_eval: Dict[PolicyID, List[PolicyEvalData]], eval_results: Dict[
-            PolicyID, Tuple[TensorStructType, StateBatch, dict]],
-        active_episodes: Dict[str, MultiAgentEpisode], active_envs: Set[int],
-        off_policy_actions: MultiEnvDict, policies: Dict[PolicyID, Policy],
-        clip_actions: bool) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
+        *,
+        to_eval: Dict[PolicyID, List[PolicyEvalData]],
+        eval_results: Dict[PolicyID, Tuple[
+            TensorStructType, StateBatch, dict]],
+        active_episodes: Dict[str, MultiAgentEpisode],
+        active_envs: Set[int],
+        off_policy_actions: MultiEnvDict,
+        policies: Dict[PolicyID, Policy],
+        clip_actions: bool,
+        _use_trajectory_view_api: bool = False
+) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
     """Process the output of policy neural network evaluation.

     Records policy evaluation results into the given episode objects and
@@ -911,6 +954,9 @@ def _process_policy_eval_results(
         policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy.
         clip_actions (bool): Whether to clip actions to the action space's
             bounds.
+        _use_trajectory_view_api (bool): Whether to use the (experimental)
+            `_use_trajectory_view_api` to make generic trajectory views
+            available to Models. Default: False.

     Returns:
         actions_to_send: Nested dict of env id -> agent id -> agent replies.
rllib/evaluation/tests/test_trajectories.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+from gym.spaces import Box, Discrete
+import numpy as np
+import unittest
+
+from ray.rllib.evaluation.trajectory import Trajectory
+
+
+class TestTrajectories(unittest.TestCase):
+    """Tests Trajectory classes."""
+
+    def test_trajectory(self):
+        """Tests the Trajectory class."""
+
+        buffer_size = 5
+
+        # Small trajecory object for testing purposes.
+        trajectory = Trajectory(buffer_size=buffer_size)
+        self.assertEqual(trajectory.cursor, 0)
+        self.assertEqual(trajectory.timestep, 0)
+        self.assertEqual(trajectory.sample_batch_offset, 0)
+        assert not trajectory.buffers
+        observation_space = Box(-1.0, 1.0, shape=(3, ))
+        action_space = Discrete(2)
+        trajectory.add_init_obs(
+            env_id=0,
+            agent_id="agent",
+            policy_id="policy",
+            init_obs=observation_space.sample())
+        self.assertEqual(trajectory.cursor, 0)
+        self.assertEqual(trajectory.initial_obs.shape, observation_space.shape)
+
+        # Fill up the buffer and make it extend if it hits the limit.
+        cur_buffer_size = buffer_size
+        for i in range(buffer_size + 1):
+            trajectory.add_action_reward_next_obs(
+                env_id=0,
+                agent_id="agent",
+                policy_id="policy",
+                values=dict(
+                    t=i,
+                    actions=action_space.sample(),
+                    rewards=1.0,
+                    dones=i == buffer_size,
+                    new_obs=observation_space.sample(),
+                    action_logp=-0.5,
+                    action_dist_inputs=np.array([[0.5, 0.5]]),
+                ))
+            self.assertEqual(trajectory.cursor, i + 1)
+            self.assertEqual(trajectory.timestep, i + 1)
+            self.assertEqual(trajectory.sample_batch_offset, 0)
+            if i == buffer_size - 1:
+                cur_buffer_size *= 2
+                self.assertEqual(
+                    len(trajectory.buffers["new_obs"]), cur_buffer_size)
+                self.assertEqual(
+                    len(trajectory.buffers["rewards"]), cur_buffer_size)
+
+        # Create a SampleBatch from the Trajectory and reset it.
+        batch = trajectory.get_sample_batch_and_reset()
+        self.assertEqual(batch.count, buffer_size + 1)
+        # Make sure, Trajectory was reset properly.
+        self.assertEqual(trajectory.cursor, buffer_size + 1)
+        self.assertEqual(trajectory.timestep, 0)
+        self.assertEqual(trajectory.sample_batch_offset, buffer_size + 1)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
rllib/evaluation/trajectory.py (new file, 267 lines)
@@ -0,0 +1,267 @@
+import logging
+import numpy as np
+from typing import Dict, Optional
+
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.types import AgentID, EnvID, PolicyID, TensorType
+
+tf1, tf, tfv = try_import_tf()
+torch, _ = try_import_torch()
+
+logger = logging.getLogger(__name__)
+
+
+def to_float_array(v):
+    if torch and isinstance(v[0], torch.Tensor):
+        arr = torch.stack(v).numpy()  # np.array([s.numpy() for s in v])
+    else:
+        arr = np.array(v)
+    if arr.dtype == np.float64:
+        return arr.astype(np.float32)  # save some memory
+    return arr
+
+
+class Trajectory:
+    """A trajectory of a (single) agent throughout one episode.
+
+    Note: This is an experimental class only used when
+    `config._use_trajectory_view_api` = True.
+
+    Collects all data produced by the environment during stepping of the agent
+    as well as all model outputs associated with the agent's Policy into
+    pre-allocated buffers of n timesteps capacity (`self.buffer_size`).
+    NOTE: A Trajectory object may contain remainders of a previous trajectory,
+    however, these are only kept for avoiding memory re-allocations. A
+    convenience cursor and offset-pointers allow for only "viewing" the
+    currently ongoing trajectory.
+    Memory re-allocation into larger buffers (`self.buffer_size *= 2`) only
+    happens if unavoidable (in case the buffer is full AND the currently
+    ongoing trajectory (episode) takes more than half of the buffer). In all
+    other cases, the same buffer is used for succeeding episodes/trejactories
+    (even for different agents).
+    """
+
+    # Disambiguate unrolls within a single episode.
+    _next_unroll_id = 0
+
+    def __init__(self, buffer_size: Optional[int] = None):
+        """Initializes a Trajectory object.
+
+        Args:
+            buffer_size (Optional[int]): The max number of timesteps to
+                fit into one buffer column. When re-allocating
+        """
+        # The current occupant (agent X in env Y using policy Z) of our
+        # buffers.
+        self.env_id: EnvID = None
+        self.agent_id: AgentID = None
+        self.policy_id: PolicyID = None
+
+        # Determine the size of the initial buffers.
+        self.buffer_size = buffer_size or 1000
+        # The actual buffer holding dict (by column name (str) ->
+        # numpy/torch/tf tensors).
+        self.buffers = {}
+
+        # Holds the initial observation data.
+        self.initial_obs = None
+
+        # Cursor into the preallocated buffers. This is where all new data
+        # gets inserted.
+        self.cursor: int = 0
+        # The offset inside our buffer where the current trajectory starts.
+        self.trajectory_offset: int = 0
+        # The offset inside our buffer, from where to build the next
+        # SampleBatch.
+        self.sample_batch_offset: int = 0
+
+    @property
+    def timestep(self) -> int:
+        """The timestep in the (currently ongoing) trajectory/episode."""
+        return self.cursor - self.trajectory_offset
+
+    def add_init_obs(self,
+                     env_id: EnvID,
+                     agent_id: AgentID,
+                     policy_id: PolicyID,
+                     init_obs: TensorType) -> None:
+        """Adds a single initial observation (after env.reset()) to the buffer.
+
+        Stores it in self.initial_obs.
+
+        Args:
+            env_id (EnvID): Unique id for the episode we are adding the initial
+                observation for.
+            agent_id (AgentID): Unique id for the agent we are adding the
+                initial observation for.
+            policy_id (PolicyID): Unique id for policy controlling the agent.
+            init_obs (TensorType): Initial observation (after env.reset()).
+        """
+        self.env_id = env_id
+        self.agent_id = agent_id
+        self.policy_id = policy_id
+        self.initial_obs = init_obs
+
+    def add_action_reward_next_obs(self,
+                                   env_id: EnvID,
+                                   agent_id: AgentID,
+                                   policy_id: PolicyID,
+                                   values: Dict[str, TensorType]) -> None:
+        """Add the given dictionary (row) of values to this batch.
+
+        Args:
+            env_id (EnvID): Unique id for the episode we are adding the initial
+                observation for.
+            agent_id (AgentID): Unique id for the agent we are adding the
+                initial observation for.
+            policy_id (PolicyID): Unique id for policy controlling the agent.
+            values (Dict[str, TensorType]): Data dict (interpreted as a single
+                row) to be added to buffer. Must contain keys:
+                SampleBatch.ACTIONS, REWARDS, DONES, and OBS.
+        """
+        assert self.initial_obs is not None
+        assert (SampleBatch.ACTIONS in values and SampleBatch.REWARDS in values
+                and SampleBatch.NEXT_OBS in values)
+        assert env_id == self.env_id
+        assert agent_id == self.agent_id
+        assert policy_id == self.policy_id
+
+        # Only obs exists so far in buffers:
+        # Initialize all other columns.
+        if len(self.buffers) == 0:
+            self._build_buffers(single_row=values)
+
+        for k, v in values.items():
+            self.buffers[k][self.cursor] = v
+        self.cursor += 1
+
+        # Extend (re-alloc) buffers if full.
+        if self.cursor == self.buffer_size:
+            self._extend_buffers(values)
+
+    def get_sample_batch_and_reset(self) -> SampleBatch:
+        """Returns a SampleBatch carrying all previously added data.
+
+        If a reset happens and the trajectory is not done yet, we'll keep the
+        entire ongoing trajectory in memory for Model view requirement purposes
+        and only actually free the data, once the episode ends.
+
+        Returns:
+            SampleBatch: The SampleBatch containing this agent's data for the
+                entire trajectory (so far). The trajectory may not be
+                terminated yet. This SampleBatch object will contain a
+                `_last_obs` property, which contains the last observation for
+                this agent. This should be used by postprocessing functions
+                instead of the SampleBatch.NEXT_OBS field, which is deprecated.
+        """
+        assert SampleBatch.UNROLL_ID not in self.buffers
+
+        # Convert all our data to numpy arrays, compress float64 to float32,
+        # and add the last observation data as well (always one more obs than
+        # all other columns due to the additional obs returned by Env.reset()).
+        data = {}
+        for k, v in self.buffers.items():
+            data[k] = to_float_array(
+                v[self.sample_batch_offset:self.cursor])
+
+        # Add unroll ID column to batch if non-existent.
+        uid = Trajectory._next_unroll_id
+        data[SampleBatch.UNROLL_ID] = np.repeat(
+            uid, self.cursor - self.sample_batch_offset)
+
+        inputs = {uid: {}}
+        if "t" in self.buffers:
+            if self.buffers["t"][self.sample_batch_offset] > 0:
+                for k in self.buffers.keys():
+                    inputs[uid][k] = \
+                        self.buffers[k][self.sample_batch_offset - 1]
+            else:
+                inputs[uid][SampleBatch.NEXT_OBS] = self.initial_obs
+        else:
+            inputs[uid][SampleBatch.NEXT_OBS] = self.initial_obs
+
+        Trajectory._next_unroll_id += 1
+
+        batch = SampleBatch(data, _initial_inputs=inputs)
+
+        # If done at end -> We can reset our buffers entirely.
+        if self.buffers[SampleBatch.DONES][self.cursor - 1]:
+            # Set self.timestep to 0 -> new trajectory w/o re-alloc (not yet,
+            # only ever re-alloc when necessary).
+            self.trajectory_offset = self.sample_batch_offset = self.cursor
+        # No done at end -> leave trajectory_offset as is (trajectory is still
+        # ongoing), but move the sample_batch offset to cursor.
+        else:
+            self.sample_batch_offset = self.cursor
+        return batch
+
+    def _build_buffers(self, single_row):
+        """Creates zero-filled pre-allocated numpy buffers for data collection.
+
+        Except for the obs-column, which should already be initialized (done
+        on call to `self.add_initial_observation()`).
+
+        Args:
+            single_row (Dict[str,np.ndarray]): Dict of column names (keys) and
+                sample numpy data (values). Note: Only one of `single_data` or
+                `data_batch` must be provided.
+        """
+        for col, data in single_row.items():
+            # Skip already initialized ones, e.g. 'obs' if used with
+            # add_initial_observation.
+            if col in self.buffers:
+                continue
+            self.buffers[col] = [None] * self.buffer_size
+
+    def _extend_buffers(self, single_row):
+        """Extends the buffers (depending on trajectory state/length).
+
+        - Extend all buffer lists (x2) if trajectory starts at 0 (trajectory is
+          longer than current self.buffer_size).
+        - Trajectory starts in first half of buffer: Create new buffer lists
+          (2x buffer sizes) and move Trajectory to beginning of new buffer.
+        - Trajectory starts in last half of buffer: Leave buffer as is, but
+          move trajectory to very front (cursor=0).
+
+        Args:
+            single_row (dict): Data dict example to use in case we have to
+                re-build buffer.
+        """
+        traj_length = self.cursor - self.trajectory_offset
+
+        # Trajectory starts at 0 (meaning episodes are longer than current
+        # `self.buffer_size` -> Simply do a resize (enlarge) on each column
+        # in the buffer.
+        if self.trajectory_offset == 0:
+            # Double actual horizon.
+            for col, data in self.buffers.items():
+                self.buffers[col].extend([None] * self.buffer_size)
+            self.buffer_size *= 2
+
+        # Trajectory starts in first half of the buffer -> Reallocate a new
+        # buffer and copy the currently ongoing trajectory into the new buffer.
+        elif self.trajectory_offset < self.buffer_size / 2:
+            # Double actual horizon.
+            self.buffer_size *= 2
+            # Store currently ongoing trajectory and build a new buffer.
+            old_buffers = self.buffers
+            self.buffers = {}
+            self._build_buffers(single_row)
+            # Copy the still ongoing trajectory into the new buffer.
+            for col, data in old_buffers.items():
+                self.buffers[col][:traj_length] = data[self.trajectory_offset:
+                                                       self.cursor]
+
+        # Do an efficient memory swap: Move current trajectory simply to
+        # the beginning of the buffer (no reallocation/None-padding necessary).
+        else:
+            for col, data in self.buffers.items():
+                self.buffers[col][:traj_length] = self.buffers[col][
+                    self.trajectory_offset:self.cursor]
+
+        # Set all pointers to their correct new values.
+        self.sample_batch_offset = (
+            self.sample_batch_offset - self.trajectory_offset)
+        self.trajectory_offset = 0
+        self.cursor = traj_length
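To make the cursor/offset bookkeeping above concrete, here is a short sketch that mirrors the new test file's usage (the episode values are made up; the import path is the one the test file uses):

# Sketch: pointer behavior of the Trajectory class added above.
import numpy as np
from ray.rllib.evaluation.trajectory import Trajectory

traj = Trajectory(buffer_size=8)
traj.add_init_obs(env_id=0, agent_id="a", policy_id="p",
                  init_obs=np.zeros(3, dtype=np.float32))
for t in range(3):
    traj.add_action_reward_next_obs(
        env_id=0, agent_id="a", policy_id="p",
        values=dict(t=t, actions=0, rewards=1.0, dones=t == 2,
                    new_obs=np.zeros(3, dtype=np.float32)))

batch = traj.get_sample_batch_and_reset()
# dones[-1] was True -> both trajectory_offset and sample_batch_offset jump
# to the cursor (3); the underlying buffer is reused, not reallocated.
assert batch.count == 3
assert traj.trajectory_offset == traj.sample_batch_offset == traj.cursor == 3

Had the last `dones` value been False, only `sample_batch_offset` would move, so the still-ongoing trajectory stays viewable in the buffer, matching the docstring of `get_sample_batch_and_reset()`.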
@@ -266,6 +266,10 @@ class ModelV2:
         # Single requirement: Pass current obs as input.
         return {
             SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
+            SampleBatch.PREV_ACTIONS:
+                ViewRequirement(SampleBatch.ACTIONS, timesteps=-1),
+            SampleBatch.PREV_REWARDS:
+                ViewRequirement(SampleBatch.REWARDS, timesteps=-1),
         }

     def import_from_h5(self, h5_file):
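Reading the dict above: each key names a column the model wants in its input dict, the first ViewRequirement argument appears to name the underlying trajectory column it is built from (PREV_ACTIONS reads ACTIONS), and `timesteps` the offset (0 = current step, -1 = one step back). A tiny, self-contained numpy sketch of what such a shifted view means (conceptual only, not RLlib code; the zero padding at t=0 is an arbitrary choice of the sketch):

import numpy as np

# A toy trajectory column and the two kinds of views requested above.
actions = np.array([0, 1, 1, 0, 1])
cur_actions = actions                               # timesteps=0
prev_actions = np.concatenate([[0], actions[:-1]])  # timesteps=-1
assert prev_actions[1] == actions[0]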
@@ -338,7 +342,7 @@ class NullContextManager:
 @DeveloperAPI
 def flatten(obs, framework):
     """Flatten the given tensor."""
-    if framework in ["tf", "tfe"]:
+    if framework in ["tf2", "tf", "tfe"]:
         return tf1.keras.layers.Flatten()(obs)
     elif framework == "torch":
         assert torch is not None
@@ -217,7 +217,7 @@ class Policy(metaclass=ABCMeta):
     def compute_actions_from_trajectories(
             self,
             trajectories: List["Trajectory"],
-            other_trajectories: Dict[AgentID, "Trajectory"],
+            other_trajectories: Optional[Dict[AgentID, "Trajectory"]] = None,
             explore: bool = None,
             timestep: Optional[int] = None,
             **kwargs) -> \
@@ -226,14 +226,14 @@ class Policy(metaclass=ABCMeta):

         Note: This is an experimental API method.

-        Only used so far by the Sampler iff `_fast_sampling=True` (also only
-        supported for torch).
+        Only used so far by the Sampler iff `_use_trajectory_view_api=True`
+        (also only supported for torch).

         Args:
             trajectories (List[Trajectory]): A List of Trajectory data used
                 to create a view for the Model forward call.
-            other_trajectories (Dict[AgentID, Trajectory]): Optional dict
-                mapping AgentIDs to Trajectory objects.
+            other_trajectories (Optional[Dict[AgentID, Trajectory]]): Optional
+                dict mapping AgentIDs to Trajectory objects.
             explore (bool): Whether to pick an exploitation or exploration
                 action (default: None -> use self.config["explore"]).
             timestep (Optional[int]): The current (sampling) time step.
@@ -58,6 +58,8 @@ class SampleBatch:
     def __init__(self, *args, **kwargs):
         """Constructs a sample batch (same params as dict constructor)."""

+        self._initial_inputs = kwargs.pop("_initial_inputs", {})
+
         self.data = dict(*args, **kwargs)
         lengths = []
         for k, v in self.data.copy().items():
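A small sketch of what the new kwarg does: it is popped before the remaining arguments become data columns, which is how `Trajectory.get_sample_batch_and_reset()` above attaches its `inputs` dict without adding a column (values below are made up):

# Sketch: `_initial_inputs` is popped, everything else becomes a column.
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch(
    {"obs": np.zeros((2, 3)), "rewards": np.ones(2)},
    _initial_inputs={0: {"new_obs": np.zeros(3)}})
assert "_initial_inputs" not in batch.data
assert batch._initial_inputs[0]["new_obs"].shape == (3,)
assert batch.count == 2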
@@ -168,7 +168,7 @@ class TorchPolicy(Policy):
     def compute_actions_from_trajectories(
             self,
             trajectories: List["Trajectory"],
-            other_trajectories: Dict[AgentID, "Trajectory"],
+            other_trajectories: Optional[Dict[AgentID, "Trajectory"]] = None,
             explore: bool = None,
             timestep: Optional[int] = None,
             **kwargs) -> \
@@ -1,4 +1,5 @@
 import numpy as np
+import tree

 from ray.rllib.utils.framework import try_import_tf, try_import_torch

@@ -247,3 +248,38 @@ def lstm(x,
         unrolled_outputs[:, t, :] = h_states

     return unrolled_outputs, (c_states, h_states)
+
+
+# TODO: (sven) this will replace `TorchPolicy._convert_to_non_torch_tensor()`.
+def convert_to_numpy(x, reduce_floats=False):
+    """Converts values in `stats` to non-Tensor numpy or python types.
+
+    Args:
+        stats (any): Any (possibly nested) struct, the values in which will be
+            converted and returned as a new struct with all torch/tf tensors
+            being converted to numpy types.
+        reduce_floats (bool): Whether to reduce all float64 data into float32
+            automatically.
+
+    Returns:
+        Any: A new struct with the same structure as `stats`, but with all
+            values converted to numpy arrays (on CPU).
+    """
+
+    # The mapping function used to numpyize torch/tf Tensors (and move them
+    # to the CPU beforehand).
+    def mapping(item):
+        if torch and isinstance(item, torch.Tensor):
+            ret = item.cpu().item() if len(item.size()) == 0 else \
+                item.cpu().detach().numpy()
+        elif tf and isinstance(item, tf.Tensor):
+            assert tf.executing_eagerly()
+            ret = item.cpu().numpy()
+        else:
+            ret = item
+        if reduce_floats and isinstance(ret, np.ndarray) and \
+                ret.dtype == np.float64:
+            ret = ret.astype(np.float32)
+        return ret
+
+    return tree.map_structure(mapping, x)
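A short usage sketch of the new helper, written as if it ran inside the same module as the code above (the file path is not shown in this view, so no import line is given); the example values are made up:

# Usage sketch (assumes convert_to_numpy, np and torch are in scope, as in
# the module patched above).
stats = {
    "loss": torch.tensor(0.25, dtype=torch.float64),     # 0-d -> Python float
    "logits": [torch.zeros(2, 3, dtype=torch.float64)],  # -> float32 ndarray
    "step": 7,                                           # left untouched
}
out = convert_to_numpy(stats, reduce_floats=True)
assert isinstance(out["logits"][0], np.ndarray)
assert out["logits"][0].dtype == np.float32
assert isinstance(out["loss"], float)
assert out["step"] == 7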