Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[RLlib] Trajectory view API (prep PR for switching on by default across all RLlib; plumbing only) (#11717)
Parent: c3074f559c
Commit: 5b788ccb13
15 changed files with 364 additions and 61 deletions
@@ -136,10 +136,11 @@ class _AgentCollector:
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
             if shift == 0:
-                batch_data[view_col] = np_data[data_col][self.shift_before:]
+                data = np_data[data_col][self.shift_before:]
             else:
-                batch_data[view_col] = np_data[data_col][self.shift_before +
-                                                         shift:shift]
+                data = np_data[data_col][self.shift_before + shift:shift]
+            if len(data) > 0:
+                batch_data[view_col] = data
         batch = SampleBatch(batch_data)

         if SampleBatch.UNROLL_ID not in batch.data:
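For orientation, here is the slicing rule this hunk changes, tried on a toy buffer (plain numpy; `shift_before` and the shift values below are made-up illustrations, not RLlib defaults). With a negative shift, `buf[shift_before + shift:shift]` can legitimately come back empty, which is why the new code only assigns `batch_data[view_col]` when `len(data) > 0`.

    import numpy as np

    buf = np.arange(8, dtype=np.float32)  # hypothetical per-column buffer
    shift_before = 2                      # leading padding kept in the buffer

    def view(buf, shift_before, shift):
        # Mirrors the slicing pattern from the hunk above (sketch only).
        if shift == 0:
            return buf[shift_before:]
        return buf[shift_before + shift:shift]

    print(view(buf, shift_before, 0))    # [2. 3. 4. 5. 6. 7.]
    print(view(buf, shift_before, -1))   # [1. 2. 3. 4. 5. 6.]
    print(view(buf, shift_before, -7))   # []  -> empty slice; column is skipped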
@@ -340,7 +341,7 @@ class _SimpleListCollector(_SampleCollector):
         assert self.agent_key_to_policy[agent_key] == policy_id
         policy = self.policy_map[policy_id]
         view_reqs = policy.model.inference_view_requirements if \
-            hasattr(policy, "model") else policy.view_requirements
+            getattr(policy, "model", None) else policy.view_requirements

         # Add initial obs to Trajectory.
         assert agent_key not in self.agent_collectors
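The switch from `hasattr` to `getattr(..., None)` matters when a Policy object has a `model` attribute that is set to None: `hasattr` still returns True, and the old branch would then dereference `None.inference_view_requirements`. A small illustration with toy classes (not RLlib code):

    class PolicyWithoutModel:
        model = None                        # attribute exists, but there is no model
        view_requirements = {"obs": "..."}

    p = PolicyWithoutModel()

    print(hasattr(p, "model"))        # True  -> old check picks the model branch
    print(getattr(p, "model", None))  # None  -> new check falls back correctly

    view_reqs = (p.model.inference_view_requirements
                 if getattr(p, "model", None) else p.view_requirements)
    print(view_reqs)                  # {'obs': '...'}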
@@ -388,7 +389,7 @@ class _SimpleListCollector(_SampleCollector):
         keys = self.forward_pass_agent_keys[policy_id]
         buffers = {k: self.agent_collectors[k].buffers for k in keys}
         view_reqs = policy.model.inference_view_requirements if \
-            hasattr(policy, "model") else policy.view_requirements
+            getattr(policy, "model", None) else policy.view_requirements

         input_dict = {}
         for view_col, view_req in view_reqs.items():
@@ -447,19 +448,19 @@ class _SimpleListCollector(_SampleCollector):

         for agent_id, (_, pre_batch) in pre_batches.items():
-            # Entire episode is said to be done.
-            if is_done:
-                # Error if no DONE at end of this agent's trajectory.
-                if check_dones and not pre_batch[SampleBatch.DONES][-1]:
-                    raise ValueError(
-                        "Episode {} terminated for all agents, but we still "
-                        "don't have a last observation for agent {} (policy "
-                        "{}). ".format(
-                            episode_id, agent_id, self.agent_key_to_policy[(
-                                episode_id, agent_id)]) +
-                        "Please ensure that you include the last observations "
-                        "of all live agents when setting done[__all__] to "
-                        "True. Alternatively, set no_done_at_end=True to "
-                        "allow this.")
+            # Error if no DONE at end of this agent's trajectory.
+            if is_done and check_dones and \
+                    not pre_batch[SampleBatch.DONES][-1]:
+                raise ValueError(
+                    "Episode {} terminated for all agents, but we still don't "
+                    "don't have a last observation for agent {} (policy "
+                    "{}). ".format(
+                        episode_id, agent_id, self.agent_key_to_policy[(
+                            episode_id, agent_id)]) +
+                    "Please ensure that you include the last observations "
+                    "of all live agents when setting done[__all__] to "
+                    "True. Alternatively, set no_done_at_end=True to "
+                    "allow this.")
             # If (only this?) agent is done, erase its buffer entirely.
             if pre_batch[SampleBatch.DONES][-1]:
                 del self.agent_collectors[(episode_id, agent_id)]
@@ -364,7 +364,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
             self.tf_sess, self.perf_stats, self.soft_horizon,
             self.no_done_at_end, self.observation_fn,
-            self._use_trajectory_view_api)
+            self._use_trajectory_view_api, self.sample_collector)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -613,6 +613,7 @@ def _env_runner(
                 to_eval=to_eval,
                 policies=policies,
+                _sample_collector=_sample_collector,
                 active_episodes=active_episodes,
                 tf_sess=tf_sess,
             )
         else:
@@ -1252,7 +1253,8 @@ def _do_policy_eval_w_trajectory_view_api(
         to_eval: Dict[PolicyID, List[PolicyEvalData]],
         policies: Dict[PolicyID, Policy],
         _sample_collector,
-        tf_sess=None,
+        active_episodes: Dict[str, MultiAgentEpisode],
+        tf_sess: Optional["tf.Session"] = None,
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
     """Call compute_actions on collected episode/model data to get next action.

@@ -1282,12 +1284,14 @@ def _do_policy_eval_w_trajectory_view_api(
         logger.info("Inputs to compute_actions():\n\n{}\n".format(
             summarize(to_eval)))

-    for policy_id in to_eval.keys():
+    for policy_id, eval_data in to_eval.items():
         policy: Policy = _get_or_raise(policies, policy_id)
         input_dict = _sample_collector.get_inference_input_dict(policy_id)
         eval_results[policy_id] = \
             policy.compute_actions_from_input_dict(
-                input_dict, timestep=policy.global_timestep)
+                input_dict,
+                timestep=policy.global_timestep,
+                episodes=[active_episodes[t.env_id] for t in eval_data])

     if builder:
         # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
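The calling pattern introduced here, sketched with toy stand-ins (the classes below are invented for illustration; only the loop shape, `get_inference_input_dict`, and the `episodes=` argument come from the hunk): each policy receives one batched input dict from the sample collector, and the episodes belonging to that batch are forwarded to `compute_actions_from_input_dict`.

    class ToyPolicy:
        global_timestep = 0

        def compute_actions_from_input_dict(self, input_dict, timestep=None,
                                            episodes=None):
            # Pretend every observation maps to action 0.
            return [0] * len(input_dict["obs"]), [], {}

    class ToyCollector:
        def get_inference_input_dict(self, policy_id):
            return {"obs": [[0.1, 0.2], [0.3, 0.4]]}  # one row per waiting agent

    policies = {"default_policy": ToyPolicy()}
    to_eval = {"default_policy": ["episode_a", "episode_b"]}
    collector = ToyCollector()

    eval_results = {}
    for policy_id, eval_data in to_eval.items():
        policy = policies[policy_id]
        input_dict = collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = policy.compute_actions_from_input_dict(
            input_dict, timestep=policy.global_timestep, episodes=eval_data)

    print(eval_results["default_policy"][0])  # -> [0, 0]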
@@ -58,7 +58,6 @@ class Preprocessor:
                 observation = np.array(observation)
             try:
                 if not self._obs_space.contains(observation):
-                    print()
                     raise ValueError(
                         "Observation ({}) outside given space ({})!",
                         observation, self._obs_space)
@@ -2,10 +2,12 @@ from collections import OrderedDict
 import gym
 import logging
 import numpy as np
-from typing import Callable, Dict, List, Optional, Tuple
+import re
+from typing import Callable, Dict, List, Optional, Tuple, Type

 from ray.util.debug import log_once
 from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy import TFPolicy
@@ -13,6 +15,7 @@ from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.tf_ops import get_placeholder
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
 from ray.rllib.utils.typing import ModelGradients, TensorType, \
     TrainerConfigDict
@@ -53,8 +56,9 @@ class DynamicTFPolicy(TFPolicy):
                  obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict,
-                 loss_fn: Callable[[Policy, ModelV2, type, SampleBatch],
-                                   TensorType],
+                 loss_fn: Callable[[
+                     Policy, ModelV2, Type[TFActionDistribution], SampleBatch
+                 ], TensorType],
                  *,
                  stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
                      str, TensorType]]] = None,
@@ -85,9 +89,9 @@ class DynamicTFPolicy(TFPolicy):
                 policy.
             action_space (gym.spaces.Space): Action space of the policy.
             config (TrainerConfigDict): Policy-specific configuration data.
-            loss_fn (Callable[[Policy, ModelV2, type, SampleBatch],
-                TensorType]): Function that returns a loss tensor for the
-                policy graph.
+            loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
+                SampleBatch], TensorType]): Function that returns a loss tensor
+                for the policy graph.
             stats_fn (Optional[Callable[[Policy, SampleBatch],
                 Dict[str, TensorType]]]): Optional function that returns a dict
                 of TF fetches given the policy and batch input tensors.
@@ -128,9 +132,9 @@ class DynamicTFPolicy(TFPolicy):
                 placeholders to use instead of defining new ones.
             existing_model (Optional[ModelV2]): When copying a policy, this
                 specifies an existing model to clone and share weights with.
-            get_batch_divisibility_req (Optional[Callable[[Policy], int]]]):
-                Optional callable that returns the divisibility requirement
-                for sample batches given the Policy.
+            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
+                Optional callable that returns the divisibility requirement for
+                sample batches. If None, will assume a value of 1.
             obs_include_prev_action_reward (bool): Whether to include the
                 previous action and reward in the model input (default: True).
         """
@@ -262,10 +266,10 @@ class DynamicTFPolicy(TFPolicy):

         # Phase 1 init.
         sess = tf1.get_default_session() or tf1.Session()
-        if get_batch_divisibility_req:
-            batch_divisibility_req = get_batch_divisibility_req(self)
-        else:
-            batch_divisibility_req = 1
+        batch_divisibility_req = get_batch_divisibility_req(self) if \
+            callable(get_batch_divisibility_req) else \
+            (get_batch_divisibility_req or 1)

         super().__init__(
             observation_space=obs_space,
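A quick check of what the new one-liner accepts, outside of RLlib (plain Python; the wrapper function name below is hypothetical): a callable, a plain int, or None all resolve to a sensible divisibility requirement.

    def resolve_divisibility_req(get_batch_divisibility_req, policy=None):
        # Same conditional expression as in the hunk above.
        return get_batch_divisibility_req(policy) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

    print(resolve_divisibility_req(lambda p: 32))  # 32 (callable -> call it)
    print(resolve_divisibility_req(8))             # 8  (plain int -> use it)
    print(resolve_divisibility_req(None))          # 1  (None -> default of 1)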
@@ -353,6 +357,56 @@ class DynamicTFPolicy(TFPolicy):
         else:
             return []

+    def _get_input_dict_and_dummy_batch(self, view_requirements,
+                                        existing_inputs):
+        """Creates input_dict and dummy_batch for loss initialization.
+
+        Used for managing the Policy's input placeholders and for loss
+        initialization.
+        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.
+
+        Args:
+            view_requirements (ViewReqs): The view requirements dict.
+            existing_inputs (Dict[str, tf.placeholder]): A dict of already
+                existing placeholders.
+
+        Returns:
+            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
+                input_dict/dummy_batch tuple.
+        """
+        input_dict = {}
+        dummy_batch = {}
+        for view_col, view_req in view_requirements.items():
+            # Point state_in to the already existing self._state_inputs.
+            mo = re.match("state_in_(\d+)", view_col)
+            if mo is not None:
+                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+            # State-outs (no placeholders needed).
+            elif view_col.startswith("state_out_"):
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+            # Skip action dist inputs placeholder (do later).
+            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
+                continue
+            elif view_col in existing_inputs:
+                input_dict[view_col] = existing_inputs[view_col]
+                dummy_batch[view_col] = np.zeros(
+                    shape=[
+                        1 if s is None else s
+                        for s in existing_inputs[view_col].shape.as_list()
+                    ],
+                    dtype=np.float32)
+            # All others.
+            else:
+                if view_req.used_for_training:
+                    input_dict[view_col] = get_placeholder(
+                        space=view_req.space)
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+        return input_dict, dummy_batch
+
     def _initialize_loss_dynamically(self):
         def fake_array(tensor):
             shape = tensor.shape.as_list()
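Two of the building blocks used by the new helper, tried in isolation (gym and numpy only; the space and column name are made up for illustration): `np.zeros_like([space.sample()])` produces a zero array with a leading batch dimension of 1, and the `state_in_(\d+)` pattern extracts the state slot index.

    import re
    import numpy as np
    import gym

    space = gym.spaces.Box(-1.0, 1.0, shape=(4,))

    # A batch-of-one dummy entry, as used for dummy_batch[view_col] above.
    dummy = np.zeros_like([space.sample()])
    print(dummy.shape)       # (1, 4)

    # Extracting the index from a state-in column name.
    mo = re.match(r"state_in_(\d+)", "state_in_2")
    print(int(mo.group(1)))  # 2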
@@ -16,6 +16,8 @@ from ray.rllib.utils import add_mixins
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
+from ray.rllib.utils.tf_ops import convert_to_non_tf_type
 from ray.rllib.utils.tracking_dict import UsageTrackingDict

 tf1, tf, tfv = try_import_tf()
 logger = logging.getLogger(__name__)
@@ -273,7 +275,7 @@ def build_eager_tf_policy(name,
             if before_loss_init:
                 before_loss_init(self, observation_space, action_space, config)

-            self._initialize_loss_with_dummy_batch()
+            self._initialize_loss_from_dummy_batch()
             self._loss_initialized = True

             if optimizer_fn:
@@ -363,8 +365,8 @@ def build_eager_tf_policy(name,
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                 "is_training": tf.constant(False),
             }
-            n = input_dict[SampleBatch.CUR_OBS].shape[0]
-            seq_lens = tf.ones(n, dtype=tf.int32)
+            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
+            seq_lens = tf.ones(batch_size, dtype=tf.int32)
             if obs_include_prev_action_reward:
                 if prev_action_batch is not None:
                     input_dict[SampleBatch.PREV_ACTIONS] = \
@@ -425,8 +427,7 @@ def build_eager_tf_policy(name,
                 extra_fetches.update(extra_action_fetches_fn(self))

             # Update our global timestep by the batch size.
-            self.global_timestep += len(obs_batch) if \
-                isinstance(obs_batch, (tuple, list)) else obs_batch.shape[0]
+            self.global_timestep += int(batch_size)

             return actions, state_out, extra_fetches

@@ -636,7 +637,8 @@ def build_eager_tf_policy(name,
                 })
             return fetches

-        def _initialize_loss_with_dummy_batch(self):
+        @override(Policy)
+        def _initialize_loss_from_dummy_batch(self):
             # Dummy forward pass to initialize any policy attributes, etc.
             dummy_batch = {
                 SampleBatch.CUR_OBS: np.array(
@@ -711,6 +713,16 @@ def build_eager_tf_policy(name,
             if stats_fn:
                 stats_fn(self, postprocessed_batch)

+        def _lazy_tensor_dict(self, postprocessed_batch):
+            train_batch = UsageTrackingDict(postprocessed_batch)
+            train_batch.set_get_interceptor(tf.convert_to_tensor)
+            return train_batch
+
+        def _lazy_numpy_dict(self, postprocessed_batch):
+            train_batch = UsageTrackingDict(postprocessed_batch)
+            train_batch.set_get_interceptor(convert_to_non_tf_type)
+            return train_batch
+
         @classmethod
         def with_tracing(cls):
             return traced_eager_policy(cls)
@@ -1,9 +1,11 @@
 from abc import ABCMeta, abstractmethod
 import gym
+from gym.spaces import Box
 import numpy as np
+import tree
 from typing import Dict, List, Optional

 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -227,11 +229,13 @@ class Policy(metaclass=ABCMeta):
             return single_action, [s[0] for s in state_out], \
                 {k: v[0] for k, v in info.items()}

     @DeveloperAPI
     def compute_actions_from_input_dict(
             self,
             input_dict: Dict[str, TensorType],
             explore: bool = None,
             timestep: Optional[int] = None,
+            episodes: Optional[List["MultiAgentEpisode"]] = None,
             **kwargs) -> \
             Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
         """Computes actions from collected samples (across multiple-agents).

@@ -278,6 +282,7 @@ class Policy(metaclass=ABCMeta):
             info_batch=None,
             explore=explore,
             timestep=timestep,
+            episodes=episodes,
             **kwargs,
         )

@@ -534,6 +539,162 @@ class Policy(metaclass=ABCMeta):
                 framework=getattr(self, "framework", "tf"))
         return exploration

+    def _get_default_view_requirements(self):
+        """Returns a default ViewRequirements dict.
+
+        Note: This is the base/maximum requirement dict, from which later
+        some requirements will be subtracted again automatically to streamline
+        data collection, batch creation, and data transfer.
+
+        Returns:
+            ViewReqDict: The default view requirements dict.
+        """
+
+        # Default view requirements (equal to those that we would use before
+        # the trajectory view API was introduced).
+        return {
+            SampleBatch.OBS: ViewRequirement(space=self.observation_space),
+            SampleBatch.NEXT_OBS: ViewRequirement(
+                data_col=SampleBatch.OBS,
+                shift=1,
+                space=self.observation_space),
+            SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
+            SampleBatch.REWARDS: ViewRequirement(),
+            SampleBatch.DONES: ViewRequirement(),
+            SampleBatch.INFOS: ViewRequirement(),
+            SampleBatch.EPS_ID: ViewRequirement(),
+            SampleBatch.AGENT_INDEX: ViewRequirement(),
+            "t": ViewRequirement(),
+        }
+
+    def _initialize_loss_from_dummy_batch(
+            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
+        """Performs test calls through policy's model and loss.
+
+        NOTE: This base method should work for define-by-run Policies such as
+        torch and tf-eager policies.
+
+        If required, will thereby detect automatically, which data views are
+        required by a) the forward pass, b) the postprocessing, and c) the loss
+        functions, and remove those from self.view_requirements that are not
+        necessary for these computations (to save data storage and transfer).
+
+        Args:
+            auto_remove_unneeded_view_reqs (bool): Whether to automatically
+                remove those ViewRequirements records from
+                self.view_requirements that are not needed.
+        """
+        sample_batch_size = max(self.batch_divisibility_req, 2)
+        B = 2  # For RNNs, have B=2, T=[depends on sample_batch_size]
+        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
+            sample_batch_size)
+        input_dict = self._lazy_tensor_dict(self._dummy_batch)
+        actions, state_outs, extra_outs = \
+            self.compute_actions_from_input_dict(input_dict)
+        # Add extra outs to view reqs.
+        for key, value in extra_outs.items():
+            self._dummy_batch[key] = np.zeros_like(value)
+            if key not in self.view_requirements:
+                self.view_requirements[key] = \
+                    ViewRequirement(space=gym.spaces.Box(
+                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype))
+        sb = SampleBatch(self._dummy_batch)
+        if state_outs:
+            # TODO: (sven) This hack will not work for attention net traj.
+            # view setup.
+            i = 0
+            while "state_in_{}".format(i) in sb:
+                sb["state_in_{}".format(i)] = sb["state_in_{}".format(i)][:B]
+                if "state_out_{}".format(i) in sb:
+                    sb["state_out_{}".format(i)] = \
+                        sb["state_out_{}".format(i)][:B]
+                i += 1
+        batch_for_postproc = self._lazy_numpy_dict(sb)
+        batch_for_postproc.count = sb.count
+        postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
+        if state_outs:
+            seq_len = (self.batch_divisibility_req // B) or 1
+            postprocessed_batch["seq_lens"] = \
+                np.array([seq_len for _ in range(B)], dtype=np.int32)
+        train_batch = self._lazy_tensor_dict(postprocessed_batch)
+        if self._loss is not None:
+            self._loss(self, self.model, self.dist_class, train_batch)
+
+        # Add new columns automatically to view-reqs.
+        if self.config["_use_trajectory_view_api"] and \
+                auto_remove_unneeded_view_reqs:
+            # Add those needed for postprocessing and training.
+            all_accessed_keys = train_batch.accessed_keys | \
+                batch_for_postproc.accessed_keys | \
+                batch_for_postproc.added_keys
+            for key in all_accessed_keys:
+                if key not in self.view_requirements:
+                    self.view_requirements[key] = ViewRequirement()
+            if self._loss:
+                # Tag those only needed for post-processing.
+                for key in batch_for_postproc.accessed_keys:
+                    if key not in train_batch.accessed_keys:
+                        self.view_requirements[key].used_for_training = False
+                # Remove those not needed at all (leave those that are needed
+                # by Sampler to properly execute sample collection).
+                for key in list(self.view_requirements.keys()):
+                    if key not in all_accessed_keys and key not in [
+                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                            SampleBatch.UNROLL_ID, SampleBatch.DONES] and \
+                            key not in self.model.inference_view_requirements:
+                        del self.view_requirements[key]
+            # Add those data_cols (again) that are missing and have
+            # dependencies by view_cols.
+            for key in list(self.view_requirements.keys()):
+                vr = self.view_requirements[key]
+                if vr.data_col is not None and \
+                        vr.data_col not in self.view_requirements:
+                    used_for_training = \
+                        vr.data_col in train_batch.accessed_keys
+                    self.view_requirements[vr.data_col] = \
+                        ViewRequirement(
+                            space=vr.space,
+                            used_for_training=used_for_training)
+
+    def _get_dummy_batch_from_view_requirements(
+            self, batch_size: int = 1) -> SampleBatch:
+        """Creates a numpy dummy batch based on the Policy's view requirements.
+
+        Args:
+            batch_size (int): The size of the batch to create.
+
+        Returns:
+            Dict[str, TensorType]: The dummy batch containing all zero values.
+        """
+        ret = {}
+        for view_col, view_req in self.view_requirements.items():
+            if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
+                _, shape = ModelCatalog.get_action_shape(view_req.space)
+                ret[view_col] = \
+                    np.zeros((batch_size, ) + shape[1:], np.float32)
+            else:
+                ret[view_col] = np.zeros_like(
+                    [view_req.space.sample() for _ in range(batch_size)])
+        return SampleBatch(ret)
+
+    def _update_model_inference_view_requirements_from_init_state(self):
+        """Uses this Model's initial state to auto-add necessary ViewReqs.
+
+        Can be called from within a Policy to make sure RNNs automatically
+        update their internal state-related view requirements.
+        Changes the `self.inference_view_requirements` dict.
+        """
+        model = self.model
+        # Add state-ins to this model's view.
+        for i, state in enumerate(model.get_initial_state()):
+            model.inference_view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift=-1,
+                    space=Box(-1.0, 1.0, shape=state.shape))
+            model.inference_view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
+
+
 def clip_action(action, action_space):
     """Clips all actions in `flat_actions` according to the given Spaces.
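The NEXT_OBS entry in `_get_default_view_requirements` is declared as a view onto the OBS column with `shift=1` rather than as separately stored data. A toy restatement of that idea (plain numpy, made-up trajectory; not the RLlib implementation):

    import numpy as np

    # One episode's observation column, including the final observation.
    obs = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

    # A requirement like {data_col: OBS, shift: 1} means: NEXT_OBS at step t
    # is simply OBS at step t + 1, so no second copy of the data is stored.
    cur_obs = obs[:-1]
    next_obs = obs[1:]

    print(cur_obs)   # [0. 1. 2. 3.]
    print(next_obs)  # [1. 2. 3. 4.]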
@@ -63,7 +63,7 @@ def build_tf_policy(
         ], Tuple[TensorType, type, List[TensorType]]]] = None,
         mixins: Optional[List[type]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
-        obs_include_prev_action_reward: bool = True) -> Type[TFPolicy]:
+        obs_include_prev_action_reward: bool = True) -> Type[DynamicTFPolicy]:
     """Helper function for creating a dynamic tf policy at runtime.

     Functions will be run in this order to initialize the policy:
@@ -591,6 +591,12 @@ class TorchPolicy(Policy):
             functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch

+    def _lazy_numpy_dict(self, postprocessed_batch):
+        train_batch = UsageTrackingDict(postprocessed_batch)
+        train_batch.set_get_interceptor(
+            functools.partial(convert_to_non_torch_type))
+        return train_batch
+

 # TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
 # and for all possible hyperparams, not just lr.
@@ -48,8 +48,14 @@ def build_torch_policy(
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
         before_init: Optional[Callable[
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
+        before_loss_init: Optional[Callable[[
+            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
+        ], None]] = None,
         after_init: Optional[Callable[
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
+        _after_loss_init: Optional[Callable[[
+            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
+        ], None]] = None,
         action_sampler_fn: Optional[Callable[[TensorType, List[
             TensorType]], Tuple[TensorType, TensorType]]] = None,
         action_distribution_fn: Optional[Callable[[
@@ -64,7 +70,7 @@ def build_torch_policy(
         apply_gradients_fn: Optional[Callable[
             [Policy, "torch.optim.Optimizer"], None]] = None,
         mixins: Optional[List[type]] = None,
-        view_requirements_fn: Optional[Callable[[], Dict[
+        view_requirements_fn: Optional[Callable[[Policy], Dict[
             str, ViewRequirement]]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
 ) -> Type[TorchPolicy]:
@@ -117,10 +123,17 @@ def build_torch_policy(
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
+        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
+            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
+            run prior to loss init. If None, this step will be skipped.
         after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
-            TrainerConfigDict], None]]): Optional callable to run at the end of
-            policy init that takes the same arguments as the policy
-            constructor. If None, this step will be skipped.
+            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
+            instead.
+        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
+            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
+            run after the loss init. If None, this step will be skipped.
+            This will be deprecated at some point and renamed into `after_init`
+            to match `build_tf_policy()` behavior.
         action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
@@ -128,13 +141,13 @@ def build_torch_policy(
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
-            TensorType, TensorType], Tuple[TensorType, type,
-            List[TensorType]]]]): A callable that takes
-            the Policy, Model, the observation batch, an explore-flag, a
-            timestep, and an is_training flag and returns a tuple of
-            a) distribution inputs (parameters), b) a dist-class to generate
-            an action distribution object from, and c) internal-state outputs
-            (empty list if not applicable). If None, will either use
+            TensorType, TensorType], Tuple[TensorType,
+            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
+            that takes the Policy, Model, the observation batch, an
+            explore-flag, a timestep, and an is_training flag and returns a
+            tuple of a) distribution inputs (parameters), b) a dist-class to
+            generate an action distribution object from, and c) internal-state
+            outputs (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
@@ -67,14 +67,15 @@ if __name__ == "__main__":
     assert len(experiments) == 1,\
         "Error, can only run a single experiment per yaml file!"

-    print("== Test config ==")
-    print(yaml.dump(experiments))
-
     # Add torch option to exp configs.
     for exp in experiments.values():
         if args.torch:
             exp["config"]["framework"] = "torch"

+    # Print out the actual config.
+    print("== Test config ==")
+    print(yaml.dump(experiments))
+
     # Try running each test 3 times and make sure it reaches the given
     # reward.
     passed = False
@@ -35,10 +35,12 @@ def make_sample_batch(i):

 class AgentIOTest(unittest.TestCase):
     def setUp(self):
+        ray.init(num_cpus=1, ignore_reinit_error=True)
         self.test_dir = tempfile.mkdtemp()

     def tearDown(self):
         shutil.rmtree(self.test_dir)
+        ray.shutdown()

     def writeOutputs(self, output, fw):
         agent = PGTrainer(
@@ -225,7 +227,7 @@ class AgentIOTest(unittest.TestCase):

 class JsonIOTest(unittest.TestCase):
     def setUp(self):
-        ray.init(num_cpus=1)
+        ray.init(num_cpus=1, ignore_reinit_error=True)
         self.test_dir = tempfile.mkdtemp()

     def tearDown(self):
@@ -1,4 +1,4 @@
-cartpole-dqn:
+cartpole-simpleq:
     env: CartPole-v0
     run: SimpleQ
    stop:
@@ -1,14 +1,61 @@
+import gym
 import numpy as np
+import tree

 from ray.rllib.utils.framework import try_import_tf

 tf1, tf, tfv = try_import_tf()


+def convert_to_non_tf_type(stats):
+    """Converts values in `stats` to non-Tensor numpy or python types.
+
+    Args:
+        stats (any): Any (possibly nested) struct, the values in which will be
+            converted and returned as a new struct with all tf (eager) tensors
+            being converted to numpy types.
+
+    Returns:
+        Any: A new struct with the same structure as `stats`, but with all
+            values converted to non-tf Tensor types.
+    """
+
+    # The mapping function used to numpyize torch Tensors.
+    def mapping(item):
+        if isinstance(item, (tf.Tensor, tf.Variable)):
+            return item.numpy()
+        else:
+            return item
+
+    return tree.map_structure(mapping, stats)
+
+
 def explained_variance(y, pred):
     _, y_var = tf.nn.moments(y, axes=[0])
     _, diff_var = tf.nn.moments(y - pred, axes=[0])
     return tf.maximum(-1.0, 1 - (diff_var / y_var))


+def get_placeholder(*, space=None, value=None):
+    from ray.rllib.models.catalog import ModelCatalog
+
+    if space is not None:
+        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
+            return ModelCatalog.get_action_placeholder(space, None)
+        return tf1.placeholder(
+            shape=(None, ) + space.shape,
+            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
+        )
+    else:
+        assert value is not None
+        shape = value.shape[1:]
+        return tf1.placeholder(
+            shape=(None, ) + (shape if isinstance(shape, tuple) else tuple(
+                shape.as_list())),
+            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
+        )
+
+
 def huber_loss(x, delta=1.0):
     """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
     return tf.where(
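For reference, the `explained_variance` formula above restated with plain numpy (illustration only, not the RLlib code path): it computes 1 - Var(y - pred) / Var(y), floored at -1.

    import numpy as np

    def explained_variance_np(y, pred):
        # 1 - Var(residual) / Var(target), clipped below at -1.0.
        y_var = np.var(y, axis=0)
        diff_var = np.var(y - pred, axis=0)
        return np.maximum(-1.0, 1.0 - diff_var / y_var)

    y = np.array([1.0, 2.0, 3.0, 4.0])
    print(explained_variance_np(y, y))                 # 1.0 (perfect predictions)
    print(explained_variance_np(y, np.zeros_like(y)))  # 0.0 (no variance explained)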
@@ -125,11 +125,11 @@ def minimize_and_clip(optimizer, clip_val=10):

 def one_hot(x, space):
     if isinstance(space, Discrete):
-        return nn.functional.one_hot(x, space.n)
+        return nn.functional.one_hot(x.long(), space.n)
     elif isinstance(space, MultiDiscrete):
         return torch.cat(
             [
-                nn.functional.one_hot(x[:, i], n)
+                nn.functional.one_hot(x[:, i].long(), n)
                 for i, n in enumerate(space.nvec)
             ],
             dim=-1)
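The added `.long()` casts matter because `torch.nn.functional.one_hot` only accepts integer index tensors, while action batches often arrive as floats. A quick check (PyTorch only, toy values):

    import torch
    import torch.nn.functional as F

    actions = torch.tensor([0.0, 2.0, 1.0])  # float action indices

    # F.one_hot(actions, 3) would raise an error for the non-integer input;
    # casting first works:
    print(F.one_hot(actions.long(), 3))
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 0]])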
@@ -11,6 +11,7 @@ class UsageTrackingDict(dict):
     def __init__(self, *args, **kwargs):
         dict.__init__(self, *args, **kwargs)
         self.accessed_keys = set()
+        self.added_keys = set()
         self.intercepted_values = {}
         self.get_interceptor = None

@@ -32,6 +33,8 @@ class UsageTrackingDict(dict):
         return value

     def __setitem__(self, key, value):
+        if key not in self:
+            self.added_keys.add(key)
         dict.__setitem__(self, key, value)
         if key in self.intercepted_values:
             self.intercepted_values[key] = value
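A minimal stand-in for what the `added_keys` bookkeeping above enables (a toy re-implementation, not RLlib's UsageTrackingDict): after postprocessing, the policy can distinguish columns that were merely read from columns that were newly created.

    class TrackingDict(dict):
        """Toy dict recording which keys were read and which were added."""

        def __init__(self, *args, **kwargs):
            dict.__init__(self, *args, **kwargs)
            self.accessed_keys = set()
            self.added_keys = set()

        def __getitem__(self, key):
            self.accessed_keys.add(key)
            return dict.__getitem__(self, key)

        def __setitem__(self, key, value):
            if key not in self:
                self.added_keys.add(key)
            dict.__setitem__(self, key, value)

    batch = TrackingDict({"obs": [1, 2], "rewards": [0.0, 1.0]})
    _ = batch["obs"]                  # a column that gets read
    batch["advantages"] = [0.5, 1.5]  # a column added during postprocessing

    print(batch.accessed_keys)  # {'obs'}
    print(batch.added_keys)     # {'advantages'}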