[RLlib] Trajectory view API (prep PR for switching on by default across all RLlib; plumbing only) (#11717)

Sven Mika 2020-11-03 21:53:34 +01:00 committed by GitHub
parent c3074f559c
commit 5b788ccb13
15 changed files with 364 additions and 61 deletions

View file

@@ -136,10 +136,11 @@ class _AgentCollector:
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
             if shift == 0:
-                batch_data[view_col] = np_data[data_col][self.shift_before:]
+                data = np_data[data_col][self.shift_before:]
             else:
-                batch_data[view_col] = np_data[data_col][self.shift_before +
-                                                         shift:shift]
+                data = np_data[data_col][self.shift_before + shift:shift]
+            if len(data) > 0:
+                batch_data[view_col] = data
 
         batch = SampleBatch(batch_data)
         if SampleBatch.UNROLL_ID not in batch.data:
@@ -340,7 +341,7 @@ class _SimpleListCollector(_SampleCollector):
         assert self.agent_key_to_policy[agent_key] == policy_id
         policy = self.policy_map[policy_id]
         view_reqs = policy.model.inference_view_requirements if \
-            hasattr(policy, "model") else policy.view_requirements
+            getattr(policy, "model", None) else policy.view_requirements
 
         # Add initial obs to Trajectory.
         assert agent_key not in self.agent_collectors
@@ -388,7 +389,7 @@ class _SimpleListCollector(_SampleCollector):
         keys = self.forward_pass_agent_keys[policy_id]
         buffers = {k: self.agent_collectors[k].buffers for k in keys}
         view_reqs = policy.model.inference_view_requirements if \
-            hasattr(policy, "model") else policy.view_requirements
+            getattr(policy, "model", None) else policy.view_requirements
 
         input_dict = {}
         for view_col, view_req in view_reqs.items():
@@ -447,19 +448,19 @@ class _SimpleListCollector(_SampleCollector):
         for agent_id, (_, pre_batch) in pre_batches.items():
             # Entire episode is said to be done.
-            if is_done:
-                # Error if no DONE at end of this agent's trajectory.
-                if check_dones and not pre_batch[SampleBatch.DONES][-1]:
-                    raise ValueError(
-                        "Episode {} terminated for all agents, but we still "
-                        "don't have a last observation for agent {} (policy "
-                        "{}). ".format(
-                            episode_id, agent_id, self.agent_key_to_policy[(
-                                episode_id, agent_id)]) +
-                        "Please ensure that you include the last observations "
-                        "of all live agents when setting done[__all__] to "
-                        "True. Alternatively, set no_done_at_end=True to "
-                        "allow this.")
+            # Error if no DONE at end of this agent's trajectory.
+            if is_done and check_dones and \
+                    not pre_batch[SampleBatch.DONES][-1]:
+                raise ValueError(
+                    "Episode {} terminated for all agents, but we still "
+                    "don't have a last observation for agent {} (policy "
+                    "{}). ".format(
+                        episode_id, agent_id, self.agent_key_to_policy[(
+                            episode_id, agent_id)]) +
+                    "Please ensure that you include the last observations "
+                    "of all live agents when setting done[__all__] to "
+                    "True. Alternatively, set no_done_at_end=True to "
+                    "allow this.")
             # If (only this?) agent is done, erase its buffer entirely.
             if pre_batch[SampleBatch.DONES][-1]:
                 del self.agent_collectors[(episode_id, agent_id)]

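Aside (not part of the diff): the new slicing branch above can be seen in isolation. A minimal numpy sketch with illustrative names, where `shift_before` is the front padding kept in each agent buffer and `shift` is the (non-positive) offset coming from a ViewRequirement:

    import numpy as np

    def slice_view(buffer, shift_before, shift):
        # Mirrors the branch added above: shift == 0 returns the "current"
        # timesteps, a negative shift returns the same window moved back.
        if shift == 0:
            return buffer[shift_before:]
        return buffer[shift_before + shift:shift]

    obs = np.arange(6)            # index 0 acts as the padding slot
    print(slice_view(obs, 1, 0))  # [1 2 3 4 5]: the "current" view
    print(slice_view(obs, 1, -1)) # [0 1 2 3 4]: the view shifted back one step
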
View file

@@ -364,7 +364,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
             self.tf_sess, self.perf_stats, self.soft_horizon,
             self.no_done_at_end, self.observation_fn,
-            self._use_trajectory_view_api)
+            self._use_trajectory_view_api, self.sample_collector)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -613,6 +613,7 @@ def _env_runner(
                 to_eval=to_eval,
                 policies=policies,
                 _sample_collector=_sample_collector,
+                active_episodes=active_episodes,
                 tf_sess=tf_sess,
             )
         else:
@@ -1252,7 +1253,8 @@ def _do_policy_eval_w_trajectory_view_api(
         to_eval: Dict[PolicyID, List[PolicyEvalData]],
         policies: Dict[PolicyID, Policy],
         _sample_collector,
-        tf_sess=None,
+        active_episodes: Dict[str, MultiAgentEpisode],
+        tf_sess: Optional["tf.Session"] = None,
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
     """Call compute_actions on collected episode/model data to get next action.
 
@@ -1282,12 +1284,14 @@ def _do_policy_eval_w_trajectory_view_api(
         logger.info("Inputs to compute_actions():\n\n{}\n".format(
             summarize(to_eval)))
 
-    for policy_id in to_eval.keys():
+    for policy_id, eval_data in to_eval.items():
         policy: Policy = _get_or_raise(policies, policy_id)
         input_dict = _sample_collector.get_inference_input_dict(policy_id)
         eval_results[policy_id] = \
             policy.compute_actions_from_input_dict(
-                input_dict, timestep=policy.global_timestep)
+                input_dict,
+                timestep=policy.global_timestep,
+                episodes=[active_episodes[t.env_id] for t in eval_data])
 
     if builder:
         # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]

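Aside (not part of the diff): the point of threading `active_episodes` through to `compute_actions_from_input_dict()` is that a policy can now see the episode objects behind its inference batch. A hypothetical override, only to illustrate the call shape (class name and attribute use are illustrative, not from this PR):

    from ray.rllib.policy.policy import Policy

    class EpisodeAwarePolicy(Policy):  # hypothetical subclass
        def compute_actions_from_input_dict(self, input_dict, explore=None,
                                            timestep=None, episodes=None,
                                            **kwargs):
            # `episodes`: one MultiAgentEpisode per env with pending eval data.
            if episodes is not None:
                self.last_episode_lengths = [e.length for e in episodes]
            return super().compute_actions_from_input_dict(
                input_dict, explore=explore, timestep=timestep,
                episodes=episodes, **kwargs)
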
View file

@@ -58,7 +58,6 @@ class Preprocessor:
                 observation = np.array(observation)
             try:
                 if not self._obs_space.contains(observation):
-                    print()
                     raise ValueError(
                         "Observation ({}) outside given space ({})!",
                         observation, self._obs_space)

View file

@@ -2,10 +2,12 @@ from collections import OrderedDict
 import gym
 import logging
 import numpy as np
-from typing import Callable, Dict, List, Optional, Tuple
+import re
+from typing import Callable, Dict, List, Optional, Tuple, Type
 
 from ray.util.debug import log_once
 from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy import TFPolicy
@@ -13,6 +15,7 @@ from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.tf_ops import get_placeholder
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
 from ray.rllib.utils.typing import ModelGradients, TensorType, \
     TrainerConfigDict
@@ -53,8 +56,9 @@ class DynamicTFPolicy(TFPolicy):
                  obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict,
-                 loss_fn: Callable[[Policy, ModelV2, type, SampleBatch],
-                                   TensorType],
+                 loss_fn: Callable[[
+                     Policy, ModelV2, Type[TFActionDistribution], SampleBatch
+                 ], TensorType],
                  *,
                  stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
                      str, TensorType]]] = None,
@@ -85,9 +89,9 @@ class DynamicTFPolicy(TFPolicy):
                 policy.
             action_space (gym.spaces.Space): Action space of the policy.
             config (TrainerConfigDict): Policy-specific configuration data.
-            loss_fn (Callable[[Policy, ModelV2, type, SampleBatch],
-                TensorType]): Function that returns a loss tensor for the
-                policy graph.
+            loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
+                SampleBatch], TensorType]): Function that returns a loss tensor
+                for the policy graph.
             stats_fn (Optional[Callable[[Policy, SampleBatch],
                 Dict[str, TensorType]]]): Optional function that returns a dict
                 of TF fetches given the policy and batch input tensors.
@@ -128,9 +132,9 @@ class DynamicTFPolicy(TFPolicy):
                 placeholders to use instead of defining new ones.
             existing_model (Optional[ModelV2]): When copying a policy, this
                 specifies an existing model to clone and share weights with.
-            get_batch_divisibility_req (Optional[Callable[[Policy], int]]]):
-                Optional callable that returns the divisibility requirement
-                for sample batches given the Policy.
+            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
+                Optional callable that returns the divisibility requirement for
+                sample batches. If None, will assume a value of 1.
             obs_include_prev_action_reward (bool): Whether to include the
                 previous action and reward in the model input (default: True).
         """
@@ -262,10 +266,10 @@ class DynamicTFPolicy(TFPolicy):
 
         # Phase 1 init.
         sess = tf1.get_default_session() or tf1.Session()
-        if get_batch_divisibility_req:
-            batch_divisibility_req = get_batch_divisibility_req(self)
-        else:
-            batch_divisibility_req = 1
+
+        batch_divisibility_req = get_batch_divisibility_req(self) if \
+            callable(get_batch_divisibility_req) else \
+            (get_batch_divisibility_req or 1)
 
         super().__init__(
             observation_space=obs_space,
@@ -353,6 +357,56 @@ class DynamicTFPolicy(TFPolicy):
         else:
             return []
 
+    def _get_input_dict_and_dummy_batch(self, view_requirements,
+                                        existing_inputs):
+        """Creates input_dict and dummy_batch for loss initialization.
+
+        Used for managing the Policy's input placeholders and for loss
+        initialization.
+        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.
+
+        Args:
+            view_requirements (ViewReqs): The view requirements dict.
+            existing_inputs (Dict[str, tf.placeholder]): A dict of already
+                existing placeholders.
+
+        Returns:
+            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
+                input_dict/dummy_batch tuple.
+        """
+        input_dict = {}
+        dummy_batch = {}
+        for view_col, view_req in view_requirements.items():
+            # Point state_in to the already existing self._state_inputs.
+            mo = re.match("state_in_(\d+)", view_col)
+            if mo is not None:
+                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+            # State-outs (no placeholders needed).
+            elif view_col.startswith("state_out_"):
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+            # Skip action dist inputs placeholder (do later).
+            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
+                continue
+            elif view_col in existing_inputs:
+                input_dict[view_col] = existing_inputs[view_col]
+                dummy_batch[view_col] = np.zeros(
+                    shape=[
+                        1 if s is None else s
+                        for s in existing_inputs[view_col].shape.as_list()
+                    ],
+                    dtype=np.float32)
+            # All others.
+            else:
+                if view_req.used_for_training:
+                    input_dict[view_col] = get_placeholder(
+                        space=view_req.space)
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+        return input_dict, dummy_batch
+
     def _initialize_loss_dynamically(self):
         def fake_array(tensor):
             shape = tensor.shape.as_list()

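Aside (not part of the diff): the dummy-batch rows built above rely on a small numpy trick; sampling the view requirement's space once and zeroing it yields a single row with the right shape and dtype. Sketch (the space is illustrative):

    import gym
    import numpy as np

    space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
    dummy_row = np.zeros_like([space.sample()])  # one all-zeros row
    print(dummy_row.shape, dummy_row.dtype)      # (1, 4) float32
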
View file

@@ -16,6 +16,8 @@ from ray.rllib.utils import add_mixins
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
+from ray.rllib.utils.tf_ops import convert_to_non_tf_type
+from ray.rllib.utils.tracking_dict import UsageTrackingDict
 
 tf1, tf, tfv = try_import_tf()
 logger = logging.getLogger(__name__)
@@ -273,7 +275,7 @@ def build_eager_tf_policy(name,
             if before_loss_init:
                 before_loss_init(self, observation_space, action_space, config)
 
-            self._initialize_loss_with_dummy_batch()
+            self._initialize_loss_from_dummy_batch()
             self._loss_initialized = True
 
             if optimizer_fn:
@@ -363,8 +365,8 @@ def build_eager_tf_policy(name,
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                 "is_training": tf.constant(False),
             }
-            n = input_dict[SampleBatch.CUR_OBS].shape[0]
-            seq_lens = tf.ones(n, dtype=tf.int32)
+            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
+            seq_lens = tf.ones(batch_size, dtype=tf.int32)
             if obs_include_prev_action_reward:
                 if prev_action_batch is not None:
                     input_dict[SampleBatch.PREV_ACTIONS] = \
@@ -425,8 +427,7 @@ def build_eager_tf_policy(name,
                 extra_fetches.update(extra_action_fetches_fn(self))
 
             # Update our global timestep by the batch size.
-            self.global_timestep += len(obs_batch) if \
-                isinstance(obs_batch, (tuple, list)) else obs_batch.shape[0]
+            self.global_timestep += int(batch_size)
 
             return actions, state_out, extra_fetches
@@ -636,7 +637,8 @@ def build_eager_tf_policy(name,
             })
             return fetches
 
-        def _initialize_loss_with_dummy_batch(self):
+        @override(Policy)
+        def _initialize_loss_from_dummy_batch(self):
             # Dummy forward pass to initialize any policy attributes, etc.
             dummy_batch = {
                 SampleBatch.CUR_OBS: np.array(
@@ -711,6 +713,16 @@ def build_eager_tf_policy(name,
             if stats_fn:
                 stats_fn(self, postprocessed_batch)
 
+        def _lazy_tensor_dict(self, postprocessed_batch):
+            train_batch = UsageTrackingDict(postprocessed_batch)
+            train_batch.set_get_interceptor(tf.convert_to_tensor)
+            return train_batch
+
+        def _lazy_numpy_dict(self, postprocessed_batch):
+            train_batch = UsageTrackingDict(postprocessed_batch)
+            train_batch.set_get_interceptor(convert_to_non_tf_type)
+            return train_batch
+
         @classmethod
         def with_tracing(cls):
             return traced_eager_policy(cls)

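Aside (not part of the diff): both new helpers wrap the batch in a UsageTrackingDict, whose get-interceptor converts values lazily on first access and records which keys were read. A minimal sketch (the lambda stands in for tf.convert_to_tensor):

    import numpy as np
    from ray.rllib.utils.tracking_dict import UsageTrackingDict

    batch = UsageTrackingDict({"obs": np.zeros((2, 4), np.float32)})
    batch.set_get_interceptor(lambda v: v + 1.0)  # stand-in converter
    _ = batch["obs"]            # interceptor runs on first read, result cached
    print(batch.accessed_keys)  # {'obs'}
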
View file

@@ -1,9 +1,11 @@
 from abc import ABCMeta, abstractmethod
 import gym
+from gym.spaces import Box
 import numpy as np
 import tree
 from typing import Dict, List, Optional
 
+from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -227,11 +229,13 @@ class Policy(metaclass=ABCMeta):
             return single_action, [s[0] for s in state_out], \
                 {k: v[0] for k, v in info.items()}
 
+    @DeveloperAPI
     def compute_actions_from_input_dict(
             self,
             input_dict: Dict[str, TensorType],
             explore: bool = None,
             timestep: Optional[int] = None,
+            episodes: Optional[List["MultiAgentEpisode"]] = None,
             **kwargs) -> \
             Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
         """Computes actions from collected samples (across multiple-agents).
@@ -278,6 +282,7 @@ class Policy(metaclass=ABCMeta):
             info_batch=None,
             explore=explore,
             timestep=timestep,
+            episodes=episodes,
             **kwargs,
         )
 
@@ -534,6 +539,162 @@ class Policy(metaclass=ABCMeta):
                 framework=getattr(self, "framework", "tf"))
         return exploration
 
+    def _get_default_view_requirements(self):
+        """Returns a default ViewRequirements dict.
+
+        Note: This is the base/maximum requirement dict, from which later
+        some requirements will be subtracted again automatically to streamline
+        data collection, batch creation, and data transfer.
+
+        Returns:
+            ViewReqDict: The default view requirements dict.
+        """
+        # Default view requirements (equal to those that we would use before
+        # the trajectory view API was introduced).
+        return {
+            SampleBatch.OBS: ViewRequirement(space=self.observation_space),
+            SampleBatch.NEXT_OBS: ViewRequirement(
+                data_col=SampleBatch.OBS,
+                shift=1,
+                space=self.observation_space),
+            SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
+            SampleBatch.REWARDS: ViewRequirement(),
+            SampleBatch.DONES: ViewRequirement(),
+            SampleBatch.INFOS: ViewRequirement(),
+            SampleBatch.EPS_ID: ViewRequirement(),
+            SampleBatch.AGENT_INDEX: ViewRequirement(),
+            "t": ViewRequirement(),
+        }
+
+    def _initialize_loss_from_dummy_batch(
+            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
+        """Performs test calls through policy's model and loss.
+
+        NOTE: This base method should work for define-by-run Policies such as
+        torch and tf-eager policies.
+
+        If required, will thereby detect automatically, which data views are
+        required by a) the forward pass, b) the postprocessing, and c) the loss
+        functions, and remove those from self.view_requirements that are not
+        necessary for these computations (to save data storage and transfer).
+
+        Args:
+            auto_remove_unneeded_view_reqs (bool): Whether to automatically
+                remove those ViewRequirements records from
+                self.view_requirements that are not needed.
+        """
+        sample_batch_size = max(self.batch_divisibility_req, 2)
+        B = 2  # For RNNs, have B=2, T=[depends on sample_batch_size]
+        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
+            sample_batch_size)
+        input_dict = self._lazy_tensor_dict(self._dummy_batch)
+        actions, state_outs, extra_outs = \
+            self.compute_actions_from_input_dict(input_dict)
+        # Add extra outs to view reqs.
+        for key, value in extra_outs.items():
+            self._dummy_batch[key] = np.zeros_like(value)
+            if key not in self.view_requirements:
+                self.view_requirements[key] = \
+                    ViewRequirement(space=gym.spaces.Box(
+                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype))
+        sb = SampleBatch(self._dummy_batch)
+        if state_outs:
+            # TODO: (sven) This hack will not work for attention net traj.
+            #  view setup.
+            i = 0
+            while "state_in_{}".format(i) in sb:
+                sb["state_in_{}".format(i)] = sb["state_in_{}".format(i)][:B]
+                if "state_out_{}".format(i) in sb:
+                    sb["state_out_{}".format(i)] = \
+                        sb["state_out_{}".format(i)][:B]
+                i += 1
+        batch_for_postproc = self._lazy_numpy_dict(sb)
+        batch_for_postproc.count = sb.count
+        postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
+        if state_outs:
+            seq_len = (self.batch_divisibility_req // B) or 1
+            postprocessed_batch["seq_lens"] = \
+                np.array([seq_len for _ in range(B)], dtype=np.int32)
+        train_batch = self._lazy_tensor_dict(postprocessed_batch)
+        if self._loss is not None:
+            self._loss(self, self.model, self.dist_class, train_batch)
+        # Add new columns automatically to view-reqs.
+        if self.config["_use_trajectory_view_api"] and \
+                auto_remove_unneeded_view_reqs:
+            # Add those needed for postprocessing and training.
+            all_accessed_keys = train_batch.accessed_keys | \
+                batch_for_postproc.accessed_keys | \
+                batch_for_postproc.added_keys
+            for key in all_accessed_keys:
+                if key not in self.view_requirements:
+                    self.view_requirements[key] = ViewRequirement()
+            if self._loss:
+                # Tag those only needed for post-processing.
+                for key in batch_for_postproc.accessed_keys:
+                    if key not in train_batch.accessed_keys:
+                        self.view_requirements[key].used_for_training = False
+                # Remove those not needed at all (leave those that are needed
+                # by Sampler to properly execute sample collection).
+                for key in list(self.view_requirements.keys()):
+                    if key not in all_accessed_keys and key not in [
+                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                            SampleBatch.UNROLL_ID, SampleBatch.DONES] and \
+                            key not in self.model.inference_view_requirements:
+                        del self.view_requirements[key]
+            # Add those data_cols (again) that are missing and have
+            # dependencies by view_cols.
+            for key in list(self.view_requirements.keys()):
+                vr = self.view_requirements[key]
+                if vr.data_col is not None and \
+                        vr.data_col not in self.view_requirements:
+                    used_for_training = \
+                        vr.data_col in train_batch.accessed_keys
+                    self.view_requirements[vr.data_col] = \
+                        ViewRequirement(
+                            space=vr.space,
+                            used_for_training=used_for_training)
+
+    def _get_dummy_batch_from_view_requirements(
+            self, batch_size: int = 1) -> SampleBatch:
+        """Creates a numpy dummy batch based on the Policy's view requirements.
+
+        Args:
+            batch_size (int): The size of the batch to create.
+
+        Returns:
+            Dict[str, TensorType]: The dummy batch containing all zero values.
+        """
+        ret = {}
+        for view_col, view_req in self.view_requirements.items():
+            if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
+                _, shape = ModelCatalog.get_action_shape(view_req.space)
+                ret[view_col] = \
+                    np.zeros((batch_size, ) + shape[1:], np.float32)
+            else:
+                ret[view_col] = np.zeros_like(
+                    [view_req.space.sample() for _ in range(batch_size)])
+        return SampleBatch(ret)
+
+    def _update_model_inference_view_requirements_from_init_state(self):
+        """Uses this Model's initial state to auto-add necessary ViewReqs.
+
+        Can be called from within a Policy to make sure RNNs automatically
+        update their internal state-related view requirements.
+        Changes the `self.inference_view_requirements` dict.
+        """
+        model = self.model
+        # Add state-ins to this model's view.
+        for i, state in enumerate(model.get_initial_state()):
+            model.inference_view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift=-1,
+                    space=Box(-1.0, 1.0, shape=state.shape))
+            model.inference_view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
+
 
 def clip_action(action, action_space):
     """Clips all actions in `flat_actions` according to the given Spaces.

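Aside (not part of the diff): the new `Policy` helpers above are driven by per-column `ViewRequirement` entries. A hypothetical model-side declaration, mirroring the `state_in_*` entries that `_update_model_inference_view_requirements_from_init_state()` adds (the class name and wiring are illustrative):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    class PrevActionModelStub:  # stand-in for a ModelV2 subclass
        def __init__(self, action_space):
            self.inference_view_requirements = {
                SampleBatch.OBS: ViewRequirement(),
                # Previous action = the ACTIONS column, shifted back one step.
                SampleBatch.PREV_ACTIONS: ViewRequirement(
                    data_col=SampleBatch.ACTIONS, shift=-1,
                    space=action_space),
            }
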
View file

@@ -63,7 +63,7 @@ def build_tf_policy(
         ], Tuple[TensorType, type, List[TensorType]]]] = None,
         mixins: Optional[List[type]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
-        obs_include_prev_action_reward: bool = True) -> Type[TFPolicy]:
+        obs_include_prev_action_reward: bool = True) -> Type[DynamicTFPolicy]:
     """Helper function for creating a dynamic tf policy at runtime.
 
     Functions will be run in this order to initialize the policy:

View file

@@ -591,6 +591,12 @@ class TorchPolicy(Policy):
             functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch
 
+    def _lazy_numpy_dict(self, postprocessed_batch):
+        train_batch = UsageTrackingDict(postprocessed_batch)
+        train_batch.set_get_interceptor(
+            functools.partial(convert_to_non_torch_type))
+        return train_batch
+
 
 # TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
 #  and for all possible hyperparams, not just lr.

View file

@@ -48,8 +48,14 @@ def build_torch_policy(
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
         before_init: Optional[Callable[
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
+        before_loss_init: Optional[Callable[[
+            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
+        ], None]] = None,
         after_init: Optional[Callable[
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
+        _after_loss_init: Optional[Callable[[
+            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
+        ], None]] = None,
         action_sampler_fn: Optional[Callable[[TensorType, List[
             TensorType]], Tuple[TensorType, TensorType]]] = None,
         action_distribution_fn: Optional[Callable[[
@@ -64,7 +70,7 @@ def build_torch_policy(
         apply_gradients_fn: Optional[Callable[
             [Policy, "torch.optim.Optimizer"], None]] = None,
         mixins: Optional[List[type]] = None,
-        view_requirements_fn: Optional[Callable[[], Dict[
+        view_requirements_fn: Optional[Callable[[Policy], Dict[
             str, ViewRequirement]]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
 ) -> Type[TorchPolicy]:
@@ -117,10 +123,17 @@ def build_torch_policy(
             TrainerConfigDict], None]]): Optional callable to run at the
             beginning of `Policy.__init__` that takes the same arguments as
             the Policy constructor. If None, this step will be skipped.
+        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
+            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
+            run prior to loss init. If None, this step will be skipped.
         after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
-            TrainerConfigDict], None]]): Optional callable to run at the end of
-            policy init that takes the same arguments as the policy
-            constructor. If None, this step will be skipped.
+            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
+            instead.
+        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
+            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
+            run after the loss init. If None, this step will be skipped.
+            This will be deprecated at some point and renamed into `after_init`
+            to match `build_tf_policy()` behavior.
         action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
             Tuple[TensorType, TensorType]]]): Optional callable returning a
             sampled action and its log-likelihood given some (obs and state)
@@ -128,13 +141,13 @@ def build_torch_policy(
             compute actions by calling self.model, then sampling from the
             so parameterized action distribution.
         action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
-            TensorType, TensorType], Tuple[TensorType, type,
-            List[TensorType]]]]): A callable that takes
-            the Policy, Model, the observation batch, an explore-flag, a
-            timestep, and an is_training flag and returns a tuple of
-            a) distribution inputs (parameters), b) a dist-class to generate
-            an action distribution object from, and c) internal-state outputs
-            (empty list if not applicable). If None, will either use
+            TensorType, TensorType], Tuple[TensorType,
+            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
+            that takes the Policy, Model, the observation batch, an
+            explore-flag, a timestep, and an is_training flag and returns a
+            tuple of a) distribution inputs (parameters), b) a dist-class to
+            generate an action distribution object from, and c) internal-state
+            outputs (empty list if not applicable). If None, will either use
             `action_sampler_fn` or compute actions by calling self.model,
             then sampling from the parameterized action distribution.
         make_model (Optional[Callable[[Policy, gym.spaces.Space,

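Aside (not part of the diff): a sketch of how the new `before_loss_init` hook would be passed to `build_torch_policy()`; the loss and hook bodies here are illustrative placeholders, not code from this PR:

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.torch_policy_template import build_torch_policy

    def dummy_loss(policy, model, dist_class, train_batch):
        # Placeholder loss so the template has something to call.
        return (train_batch[SampleBatch.REWARDS].float() * 0.0).sum()

    def setup_mixins(policy, obs_space, action_space, config):
        # Runs before the loss is first exercised on a dummy batch, e.g. to
        # attach coefficients or schedules the loss function expects.
        policy.my_coeff = config.get("gamma", 0.99)

    MyTorchPolicy = build_torch_policy(
        name="MyTorchPolicy",
        loss_fn=dummy_loss,
        before_loss_init=setup_mixins)
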
View file

@@ -67,14 +67,15 @@ if __name__ == "__main__":
     assert len(experiments) == 1,\
         "Error, can only run a single experiment per yaml file!"
 
-    print("== Test config ==")
-    print(yaml.dump(experiments))
-
     # Add torch option to exp configs.
     for exp in experiments.values():
         if args.torch:
             exp["config"]["framework"] = "torch"
 
+    # Print out the actual config.
+    print("== Test config ==")
+    print(yaml.dump(experiments))
+
     # Try running each test 3 times and make sure it reaches the given
     # reward.
     passed = False

View file

@@ -35,10 +35,12 @@ def make_sample_batch(i):
 
 class AgentIOTest(unittest.TestCase):
     def setUp(self):
+        ray.init(num_cpus=1, ignore_reinit_error=True)
         self.test_dir = tempfile.mkdtemp()
 
     def tearDown(self):
         shutil.rmtree(self.test_dir)
+        ray.shutdown()
 
     def writeOutputs(self, output, fw):
         agent = PGTrainer(
@@ -225,7 +227,7 @@ class AgentIOTest(unittest.TestCase):
 
 class JsonIOTest(unittest.TestCase):
     def setUp(self):
-        ray.init(num_cpus=1)
+        ray.init(num_cpus=1, ignore_reinit_error=True)
         self.test_dir = tempfile.mkdtemp()
 
     def tearDown(self):

View file

@@ -1,4 +1,4 @@
-cartpole-dqn:
+cartpole-simpleq:
     env: CartPole-v0
     run: SimpleQ
     stop:

View file

@@ -1,14 +1,61 @@
+import gym
+import numpy as np
+import tree
+
 from ray.rllib.utils.framework import try_import_tf
 
 tf1, tf, tfv = try_import_tf()
 
 
+def convert_to_non_tf_type(stats):
+    """Converts values in `stats` to non-Tensor numpy or python types.
+
+    Args:
+        stats (any): Any (possibly nested) struct, the values in which will be
+            converted and returned as a new struct with all tf (eager) tensors
+            being converted to numpy types.
+
+    Returns:
+        Any: A new struct with the same structure as `stats`, but with all
+            values converted to non-tf Tensor types.
+    """
+
+    # The mapping function used to numpyize tf (eager) Tensors.
+    def mapping(item):
+        if isinstance(item, (tf.Tensor, tf.Variable)):
+            return item.numpy()
+        else:
+            return item
+
+    return tree.map_structure(mapping, stats)
+
+
 def explained_variance(y, pred):
     _, y_var = tf.nn.moments(y, axes=[0])
     _, diff_var = tf.nn.moments(y - pred, axes=[0])
     return tf.maximum(-1.0, 1 - (diff_var / y_var))
 
 
+def get_placeholder(*, space=None, value=None):
+    from ray.rllib.models.catalog import ModelCatalog
+
+    if space is not None:
+        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
+            return ModelCatalog.get_action_placeholder(space, None)
+        return tf1.placeholder(
+            shape=(None, ) + space.shape,
+            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
+        )
+    else:
+        assert value is not None
+        shape = value.shape[1:]
+        return tf1.placeholder(
+            shape=(None, ) + (shape if isinstance(shape, tuple) else tuple(
+                shape.as_list())),
+            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
+        )
+
+
 def huber_loss(x, delta=1.0):
     """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
     return tf.where(

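Aside (not part of the diff): the two helpers added to tf_ops serve opposite directions. `get_placeholder()` builds graph-mode input placeholders (float64 spaces become float32, the batch dim stays None), while `convert_to_non_tf_type()` strips eager tensors back to plain numpy/python values. A small sketch of the latter, assuming TF2 eager execution:

    from ray.rllib.utils.framework import try_import_tf
    from ray.rllib.utils.tf_ops import convert_to_non_tf_type

    tf1, tf, tfv = try_import_tf()

    stats = {"policy_loss": tf.constant(0.25), "lr": 1e-4}
    print(convert_to_non_tf_type(stats))  # {'policy_loss': 0.25, 'lr': 0.0001}
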
View file

@@ -125,11 +125,11 @@ def minimize_and_clip(optimizer, clip_val=10):
 
 def one_hot(x, space):
     if isinstance(space, Discrete):
-        return nn.functional.one_hot(x, space.n)
+        return nn.functional.one_hot(x.long(), space.n)
     elif isinstance(space, MultiDiscrete):
         return torch.cat(
             [
-                nn.functional.one_hot(x[:, i], n)
+                nn.functional.one_hot(x[:, i].long(), n)
                 for i, n in enumerate(space.nvec)
             ],
             dim=-1)

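Aside (not part of the diff): the `.long()` casts above matter because `nn.functional.one_hot()` only accepts integer index tensors, while action batches frequently arrive as floats from numpy sample buffers. Minimal illustration:

    import torch
    from torch import nn

    x = torch.tensor([0.0, 2.0, 1.0])          # float "indices"
    # nn.functional.one_hot(x, 3)              # would raise a RuntimeError
    print(nn.functional.one_hot(x.long(), 3))  # 3x3 one-hot matrix
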
View file

@@ -11,6 +11,7 @@ class UsageTrackingDict(dict):
     def __init__(self, *args, **kwargs):
         dict.__init__(self, *args, **kwargs)
         self.accessed_keys = set()
+        self.added_keys = set()
         self.intercepted_values = {}
         self.get_interceptor = None
 
@@ -32,6 +33,8 @@ class UsageTrackingDict(dict):
         return value
 
     def __setitem__(self, key, value):
+        if key not in self:
+            self.added_keys.add(key)
         dict.__setitem__(self, key, value)
         if key in self.intercepted_values:
             self.intercepted_values[key] = value
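
Aside (not part of the diff): `added_keys` complements `accessed_keys`; the new `Policy._initialize_loss_from_dummy_batch()` uses it to keep view requirements for columns that postprocessing creates rather than reads. Behavior sketch (illustrative values):

    from ray.rllib.utils.tracking_dict import UsageTrackingDict

    batch = UsageTrackingDict({"obs": [1, 2, 3]})
    batch["advantages"] = [0.5, 0.5, 0.5]  # new key -> recorded
    batch["obs"] = [4, 5, 6]               # existing key -> not recorded
    print(batch.added_keys)                # {'advantages'}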