Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Trajectory view API (prep PR for switching on by default across all RLlib; plumbing only) (#11717)
This commit is contained in:
parent c3074f559c
commit 5b788ccb13
15 changed files with 364 additions and 61 deletions
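The hunks below plumb the trajectory view API through the sample collectors and the TF/Torch policy classes. As a rough sketch of the underlying idea (illustrative only, not part of this commit; it assumes an RLlib version that already ships `ViewRequirement`, mirroring the `_get_default_view_requirements()` method added further down), a view-requirements dict declares which columns a policy wants to see and how each one maps onto the stored trajectory data:

    import gym
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ))

    # NEXT_OBS is not stored separately; it is the OBS column shifted by +1.
    view_requirements = {
        SampleBatch.OBS: ViewRequirement(space=obs_space),
        SampleBatch.NEXT_OBS: ViewRequirement(
            data_col=SampleBatch.OBS, shift=1, space=obs_space),
        SampleBatch.REWARDS: ViewRequirement(),
        SampleBatch.DONES: ViewRequirement(),
    }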
@@ -136,10 +136,11 @@ class _AgentCollector:
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
             if shift == 0:
-                batch_data[view_col] = np_data[data_col][self.shift_before:]
+                data = np_data[data_col][self.shift_before:]
             else:
-                batch_data[view_col] = np_data[data_col][self.shift_before +
-                                                         shift:shift]
+                data = np_data[data_col][self.shift_before + shift:shift]
+            if len(data) > 0:
+                batch_data[view_col] = data
         batch = SampleBatch(batch_data)

         if SampleBatch.UNROLL_ID not in batch.data:
@@ -340,7 +341,7 @@ class _SimpleListCollector(_SampleCollector):
         assert self.agent_key_to_policy[agent_key] == policy_id
         policy = self.policy_map[policy_id]
         view_reqs = policy.model.inference_view_requirements if \
-            hasattr(policy, "model") else policy.view_requirements
+            getattr(policy, "model", None) else policy.view_requirements

         # Add initial obs to Trajectory.
         assert agent_key not in self.agent_collectors
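Note on the change above: `hasattr(policy, "model")` is true even when the attribute exists but holds None, while `getattr(policy, "model", None)` is falsy in that case and lets the code fall back to `policy.view_requirements`. A minimal sketch (with a made-up `_StubPolicy` class):

    class _StubPolicy:
        model = None                  # attribute exists, but no model attached
        view_requirements = {"obs": None}

    p = _StubPolicy()
    print(hasattr(p, "model"))               # True, even though model is None
    print(bool(getattr(p, "model", None)))   # False -> fall back to
                                             #   p.view_requirements instead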
@@ -388,7 +389,7 @@ class _SimpleListCollector(_SampleCollector):
         keys = self.forward_pass_agent_keys[policy_id]
         buffers = {k: self.agent_collectors[k].buffers for k in keys}
         view_reqs = policy.model.inference_view_requirements if \
-            hasattr(policy, "model") else policy.view_requirements
+            getattr(policy, "model", None) else policy.view_requirements

         input_dict = {}
         for view_col, view_req in view_reqs.items():
@@ -447,19 +448,19 @@ class _SimpleListCollector(_SampleCollector):

         for agent_id, (_, pre_batch) in pre_batches.items():
             # Entire episode is said to be done.
-            if is_done:
-                # Error if no DONE at end of this agent's trajectory.
-                if check_dones and not pre_batch[SampleBatch.DONES][-1]:
+            # Error if no DONE at end of this agent's trajectory.
+            if is_done and check_dones and \
+                    not pre_batch[SampleBatch.DONES][-1]:
                 raise ValueError(
-                    "Episode {} terminated for all agents, but we still "
+                    "Episode {} terminated for all agents, but we still don't "
                     "don't have a last observation for agent {} (policy "
                     "{}). ".format(
                         episode_id, agent_id, self.agent_key_to_policy[(
                             episode_id, agent_id)]) +
                     "Please ensure that you include the last observations "
                     "of all live agents when setting done[__all__] to "
                     "True. Alternatively, set no_done_at_end=True to "
                     "allow this.")
             # If (only this?) agent is done, erase its buffer entirely.
             if pre_batch[SampleBatch.DONES][-1]:
                 del self.agent_collectors[(episode_id, agent_id)]
@@ -364,7 +364,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
             self.tf_sess, self.perf_stats, self.soft_horizon,
             self.no_done_at_end, self.observation_fn,
-            self._use_trajectory_view_api)
+            self._use_trajectory_view_api, self.sample_collector)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -613,6 +613,7 @@ def _env_runner(
                 to_eval=to_eval,
                 policies=policies,
                 _sample_collector=_sample_collector,
+                active_episodes=active_episodes,
                 tf_sess=tf_sess,
             )
         else:
@@ -1252,7 +1253,8 @@ def _do_policy_eval_w_trajectory_view_api(
         to_eval: Dict[PolicyID, List[PolicyEvalData]],
         policies: Dict[PolicyID, Policy],
         _sample_collector,
-        tf_sess=None,
+        active_episodes: Dict[str, MultiAgentEpisode],
+        tf_sess: Optional["tf.Session"] = None,
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
     """Call compute_actions on collected episode/model data to get next action.

@@ -1282,12 +1284,14 @@ def _do_policy_eval_w_trajectory_view_api(
         logger.info("Inputs to compute_actions():\n\n{}\n".format(
             summarize(to_eval)))

-    for policy_id in to_eval.keys():
+    for policy_id, eval_data in to_eval.items():
         policy: Policy = _get_or_raise(policies, policy_id)
         input_dict = _sample_collector.get_inference_input_dict(policy_id)
         eval_results[policy_id] = \
             policy.compute_actions_from_input_dict(
-                input_dict, timestep=policy.global_timestep)
+                input_dict,
+                timestep=policy.global_timestep,
+                episodes=[active_episodes[t.env_id] for t in eval_data])

     if builder:
         # type: PolicyID, Tuple[TensorStructType, StateBatch, dict]
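For context on the new `episodes=...` argument above: each pending eval request (`PolicyEvalData`) carries the env index of the episode it came from, so the list is a simple lookup into `active_episodes`. A toy sketch with made-up stand-ins for both structures (only the `env_id` field used here is modeled):

    from collections import namedtuple

    # Stand-in with only the field used here; the real PolicyEvalData has more.
    PolicyEvalData = namedtuple("PolicyEvalData", ["env_id"])

    active_episodes = {0: "episode-A", 1: "episode-B"}
    eval_data = [PolicyEvalData(env_id=0), PolicyEvalData(env_id=1)]

    episodes = [active_episodes[t.env_id] for t in eval_data]
    print(episodes)  # ['episode-A', 'episode-B']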
@@ -58,7 +58,6 @@ class Preprocessor:
             observation = np.array(observation)
         try:
             if not self._obs_space.contains(observation):
-                print()
                 raise ValueError(
                     "Observation ({}) outside given space ({})!",
                     observation, self._obs_space)
@@ -2,10 +2,12 @@ from collections import OrderedDict
 import gym
 import logging
 import numpy as np
-from typing import Callable, Dict, List, Optional, Tuple
+import re
+from typing import Callable, Dict, List, Optional, Tuple, Type

 from ray.util.debug import log_once
 from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy import TFPolicy
@@ -13,6 +15,7 @@ from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.tf_ops import get_placeholder
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
 from ray.rllib.utils.typing import ModelGradients, TensorType, \
     TrainerConfigDict
@@ -53,8 +56,9 @@ class DynamicTFPolicy(TFPolicy):
             obs_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             config: TrainerConfigDict,
-            loss_fn: Callable[[Policy, ModelV2, type, SampleBatch],
-                              TensorType],
+            loss_fn: Callable[[
+                Policy, ModelV2, Type[TFActionDistribution], SampleBatch
+            ], TensorType],
             *,
             stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
                 str, TensorType]]] = None,
|
@ -85,9 +89,9 @@ class DynamicTFPolicy(TFPolicy):
|
||||||
policy.
|
policy.
|
||||||
action_space (gym.spaces.Space): Action space of the policy.
|
action_space (gym.spaces.Space): Action space of the policy.
|
||||||
config (TrainerConfigDict): Policy-specific configuration data.
|
config (TrainerConfigDict): Policy-specific configuration data.
|
||||||
loss_fn (Callable[[Policy, ModelV2, type, SampleBatch],
|
loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
|
||||||
TensorType]): Function that returns a loss tensor for the
|
SampleBatch], TensorType]): Function that returns a loss tensor
|
||||||
policy graph.
|
for the policy graph.
|
||||||
stats_fn (Optional[Callable[[Policy, SampleBatch],
|
stats_fn (Optional[Callable[[Policy, SampleBatch],
|
||||||
Dict[str, TensorType]]]): Optional function that returns a dict
|
Dict[str, TensorType]]]): Optional function that returns a dict
|
||||||
of TF fetches given the policy and batch input tensors.
|
of TF fetches given the policy and batch input tensors.
|
||||||
|
@@ -128,9 +132,9 @@ class DynamicTFPolicy(TFPolicy):
                 placeholders to use instead of defining new ones.
             existing_model (Optional[ModelV2]): When copying a policy, this
                 specifies an existing model to clone and share weights with.
-            get_batch_divisibility_req (Optional[Callable[[Policy], int]]]):
-                Optional callable that returns the divisibility requirement
-                for sample batches given the Policy.
+            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
+                Optional callable that returns the divisibility requirement for
+                sample batches. If None, will assume a value of 1.
             obs_include_prev_action_reward (bool): Whether to include the
                 previous action and reward in the model input (default: True).
         """
@@ -262,10 +266,10 @@ class DynamicTFPolicy(TFPolicy):

         # Phase 1 init.
         sess = tf1.get_default_session() or tf1.Session()
-        if get_batch_divisibility_req:
-            batch_divisibility_req = get_batch_divisibility_req(self)
-        else:
-            batch_divisibility_req = 1
+        batch_divisibility_req = get_batch_divisibility_req(self) if \
+            callable(get_batch_divisibility_req) else \
+            (get_batch_divisibility_req or 1)

         super().__init__(
             observation_space=obs_space,
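The conditional above now accepts either a callable or a plain value for `get_batch_divisibility_req`. A small standalone sketch of the same resolution logic (the helper name is made up for illustration):

    def resolve_batch_divisibility_req(get_batch_divisibility_req, policy=None):
        # Callable -> ask it; plain int -> use it; None/0 -> default of 1.
        return get_batch_divisibility_req(policy) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

    print(resolve_batch_divisibility_req(lambda policy: 8))  # 8
    print(resolve_batch_divisibility_req(4))                 # 4
    print(resolve_batch_divisibility_req(None))              # 1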
@@ -353,6 +357,56 @@ class DynamicTFPolicy(TFPolicy):
         else:
             return []

+    def _get_input_dict_and_dummy_batch(self, view_requirements,
+                                        existing_inputs):
+        """Creates input_dict and dummy_batch for loss initialization.
+
+        Used for managing the Policy's input placeholders and for loss
+        initialization.
+        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.
+
+        Args:
+            view_requirements (ViewReqs): The view requirements dict.
+            existing_inputs (Dict[str, tf.placeholder]): A dict of already
+                existing placeholders.
+
+        Returns:
+            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
+                input_dict/dummy_batch tuple.
+        """
+        input_dict = {}
+        dummy_batch = {}
+        for view_col, view_req in view_requirements.items():
+            # Point state_in to the already existing self._state_inputs.
+            mo = re.match("state_in_(\d+)", view_col)
+            if mo is not None:
+                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+            # State-outs (no placeholders needed).
+            elif view_col.startswith("state_out_"):
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+            # Skip action dist inputs placeholder (do later).
+            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
+                continue
+            elif view_col in existing_inputs:
+                input_dict[view_col] = existing_inputs[view_col]
+                dummy_batch[view_col] = np.zeros(
+                    shape=[
+                        1 if s is None else s
+                        for s in existing_inputs[view_col].shape.as_list()
+                    ],
+                    dtype=np.float32)
+            # All others.
+            else:
+                if view_req.used_for_training:
+                    input_dict[view_col] = get_placeholder(
+                        space=view_req.space)
+                dummy_batch[view_col] = np.zeros_like(
+                    [view_req.space.sample()])
+        return input_dict, dummy_batch
+
     def _initialize_loss_dynamically(self):
         def fake_array(tensor):
             shape = tensor.shape.as_list()
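The new `_get_input_dict_and_dummy_batch()` dispatches on the view column name; the regex branch routes `state_in_<i>` columns to the already existing state placeholders. A quick standalone illustration of that dispatch (the column names are made up):

    import re

    for view_col in ["obs", "prev_actions", "state_in_0", "state_in_12",
                     "state_out_0"]:
        mo = re.match(r"state_in_(\d+)", view_col)
        if mo is not None:
            print(view_col, "-> reuse self._state_inputs[{}]".format(mo.group(1)))
        elif view_col.startswith("state_out_"):
            print(view_col, "-> dummy data only, no placeholder needed")
        else:
            print(view_col, "-> new placeholder")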
@@ -16,6 +16,8 @@ from ray.rllib.utils import add_mixins
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
+from ray.rllib.utils.tf_ops import convert_to_non_tf_type
+from ray.rllib.utils.tracking_dict import UsageTrackingDict

 tf1, tf, tfv = try_import_tf()
 logger = logging.getLogger(__name__)
@@ -273,7 +275,7 @@ def build_eager_tf_policy(name,
             if before_loss_init:
                 before_loss_init(self, observation_space, action_space, config)

-            self._initialize_loss_with_dummy_batch()
+            self._initialize_loss_from_dummy_batch()
             self._loss_initialized = True

             if optimizer_fn:
@@ -363,8 +365,8 @@ def build_eager_tf_policy(name,
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                 "is_training": tf.constant(False),
             }
-            n = input_dict[SampleBatch.CUR_OBS].shape[0]
-            seq_lens = tf.ones(n, dtype=tf.int32)
+            batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
+            seq_lens = tf.ones(batch_size, dtype=tf.int32)
             if obs_include_prev_action_reward:
                 if prev_action_batch is not None:
                     input_dict[SampleBatch.PREV_ACTIONS] = \
@@ -425,8 +427,7 @@ def build_eager_tf_policy(name,
                 extra_fetches.update(extra_action_fetches_fn(self))

             # Update our global timestep by the batch size.
-            self.global_timestep += len(obs_batch) if \
-                isinstance(obs_batch, (tuple, list)) else obs_batch.shape[0]
+            self.global_timestep += int(batch_size)

             return actions, state_out, extra_fetches

@@ -636,7 +637,8 @@ def build_eager_tf_policy(name,
                 })
             return fetches

-        def _initialize_loss_with_dummy_batch(self):
+        @override(Policy)
+        def _initialize_loss_from_dummy_batch(self):
             # Dummy forward pass to initialize any policy attributes, etc.
             dummy_batch = {
                 SampleBatch.CUR_OBS: np.array(
@@ -711,6 +713,16 @@ def build_eager_tf_policy(name,
             if stats_fn:
                 stats_fn(self, postprocessed_batch)

+        def _lazy_tensor_dict(self, postprocessed_batch):
+            train_batch = UsageTrackingDict(postprocessed_batch)
+            train_batch.set_get_interceptor(tf.convert_to_tensor)
+            return train_batch
+
+        def _lazy_numpy_dict(self, postprocessed_batch):
+            train_batch = UsageTrackingDict(postprocessed_batch)
+            train_batch.set_get_interceptor(convert_to_non_tf_type)
+            return train_batch
+
         @classmethod
         def with_tracing(cls):
             return traced_eager_policy(cls)
@@ -1,9 +1,11 @@
 from abc import ABCMeta, abstractmethod
 import gym
+from gym.spaces import Box
 import numpy as np
 import tree
 from typing import Dict, List, Optional

+from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -227,11 +229,13 @@ class Policy(metaclass=ABCMeta):
             return single_action, [s[0] for s in state_out], \
                 {k: v[0] for k, v in info.items()}

+    @DeveloperAPI
     def compute_actions_from_input_dict(
             self,
             input_dict: Dict[str, TensorType],
             explore: bool = None,
             timestep: Optional[int] = None,
+            episodes: Optional[List["MultiAgentEpisode"]] = None,
             **kwargs) -> \
             Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
         """Computes actions from collected samples (across multiple-agents).
@@ -278,6 +282,7 @@ class Policy(metaclass=ABCMeta):
             info_batch=None,
             explore=explore,
             timestep=timestep,
+            episodes=episodes,
             **kwargs,
         )

@@ -534,6 +539,162 @@ class Policy(metaclass=ABCMeta):
                 framework=getattr(self, "framework", "tf"))
         return exploration

+    def _get_default_view_requirements(self):
+        """Returns a default ViewRequirements dict.
+
+        Note: This is the base/maximum requirement dict, from which later
+        some requirements will be subtracted again automatically to streamline
+        data collection, batch creation, and data transfer.
+
+        Returns:
+            ViewReqDict: The default view requirements dict.
+        """
+
+        # Default view requirements (equal to those that we would use before
+        # the trajectory view API was introduced).
+        return {
+            SampleBatch.OBS: ViewRequirement(space=self.observation_space),
+            SampleBatch.NEXT_OBS: ViewRequirement(
+                data_col=SampleBatch.OBS,
+                shift=1,
+                space=self.observation_space),
+            SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
+            SampleBatch.REWARDS: ViewRequirement(),
+            SampleBatch.DONES: ViewRequirement(),
+            SampleBatch.INFOS: ViewRequirement(),
+            SampleBatch.EPS_ID: ViewRequirement(),
+            SampleBatch.AGENT_INDEX: ViewRequirement(),
+            "t": ViewRequirement(),
+        }
+
+    def _initialize_loss_from_dummy_batch(
+            self, auto_remove_unneeded_view_reqs: bool = True) -> None:
+        """Performs test calls through policy's model and loss.
+
+        NOTE: This base method should work for define-by-run Policies such as
+        torch and tf-eager policies.
+
+        If required, will thereby detect automatically, which data views are
+        required by a) the forward pass, b) the postprocessing, and c) the loss
+        functions, and remove those from self.view_requirements that are not
+        necessary for these computations (to save data storage and transfer).
+
+        Args:
+            auto_remove_unneeded_view_reqs (bool): Whether to automatically
+                remove those ViewRequirements records from
+                self.view_requirements that are not needed.
+        """
+        sample_batch_size = max(self.batch_divisibility_req, 2)
+        B = 2  # For RNNs, have B=2, T=[depends on sample_batch_size]
+        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
+            sample_batch_size)
+        input_dict = self._lazy_tensor_dict(self._dummy_batch)
+        actions, state_outs, extra_outs = \
+            self.compute_actions_from_input_dict(input_dict)
+        # Add extra outs to view reqs.
+        for key, value in extra_outs.items():
+            self._dummy_batch[key] = np.zeros_like(value)
+            if key not in self.view_requirements:
+                self.view_requirements[key] = \
+                    ViewRequirement(space=gym.spaces.Box(
+                        -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype))
+        sb = SampleBatch(self._dummy_batch)
+        if state_outs:
+            # TODO: (sven) This hack will not work for attention net traj.
+            #  view setup.
+            i = 0
+            while "state_in_{}".format(i) in sb:
+                sb["state_in_{}".format(i)] = sb["state_in_{}".format(i)][:B]
+                if "state_out_{}".format(i) in sb:
+                    sb["state_out_{}".format(i)] = \
+                        sb["state_out_{}".format(i)][:B]
+                i += 1
+        batch_for_postproc = self._lazy_numpy_dict(sb)
+        batch_for_postproc.count = sb.count
+        postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
+        if state_outs:
+            seq_len = (self.batch_divisibility_req // B) or 1
+            postprocessed_batch["seq_lens"] = \
+                np.array([seq_len for _ in range(B)], dtype=np.int32)
+        train_batch = self._lazy_tensor_dict(postprocessed_batch)
+        if self._loss is not None:
+            self._loss(self, self.model, self.dist_class, train_batch)
+
+        # Add new columns automatically to view-reqs.
+        if self.config["_use_trajectory_view_api"] and \
+                auto_remove_unneeded_view_reqs:
+            # Add those needed for postprocessing and training.
+            all_accessed_keys = train_batch.accessed_keys | \
+                batch_for_postproc.accessed_keys | \
+                batch_for_postproc.added_keys
+            for key in all_accessed_keys:
+                if key not in self.view_requirements:
+                    self.view_requirements[key] = ViewRequirement()
+            if self._loss:
+                # Tag those only needed for post-processing.
+                for key in batch_for_postproc.accessed_keys:
+                    if key not in train_batch.accessed_keys:
+                        self.view_requirements[key].used_for_training = False
+                # Remove those not needed at all (leave those that are needed
+                # by Sampler to properly execute sample collection).
+                for key in list(self.view_requirements.keys()):
+                    if key not in all_accessed_keys and key not in [
+                            SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                            SampleBatch.UNROLL_ID, SampleBatch.DONES] and \
+                            key not in self.model.inference_view_requirements:
+                        del self.view_requirements[key]
+                # Add those data_cols (again) that are missing and have
+                # dependencies by view_cols.
+                for key in list(self.view_requirements.keys()):
+                    vr = self.view_requirements[key]
+                    if vr.data_col is not None and \
+                            vr.data_col not in self.view_requirements:
+                        used_for_training = \
+                            vr.data_col in train_batch.accessed_keys
+                        self.view_requirements[vr.data_col] = \
+                            ViewRequirement(
+                                space=vr.space,
+                                used_for_training=used_for_training)
+
+    def _get_dummy_batch_from_view_requirements(
+            self, batch_size: int = 1) -> SampleBatch:
+        """Creates a numpy dummy batch based on the Policy's view requirements.
+
+        Args:
+            batch_size (int): The size of the batch to create.
+
+        Returns:
+            Dict[str, TensorType]: The dummy batch containing all zero values.
+        """
+        ret = {}
+        for view_col, view_req in self.view_requirements.items():
+            if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
+                _, shape = ModelCatalog.get_action_shape(view_req.space)
+                ret[view_col] = \
+                    np.zeros((batch_size, ) + shape[1:], np.float32)
+            else:
+                ret[view_col] = np.zeros_like(
+                    [view_req.space.sample() for _ in range(batch_size)])
+        return SampleBatch(ret)
+
+    def _update_model_inference_view_requirements_from_init_state(self):
+        """Uses this Model's initial state to auto-add necessary ViewReqs.
+
+        Can be called from within a Policy to make sure RNNs automatically
+        update their internal state-related view requirements.
+        Changes the `self.inference_view_requirements` dict.
+        """
+        model = self.model
+        # Add state-ins to this model's view.
+        for i, state in enumerate(model.get_initial_state()):
+            model.inference_view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift=-1,
+                    space=Box(-1.0, 1.0, shape=state.shape))
+            model.inference_view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
+
+
 def clip_action(action, action_space):
     """Clips all actions in `flat_actions` according to the given Spaces.
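The pruning step in `_initialize_loss_from_dummy_batch()` above boils down to set arithmetic over the keys that the dummy postprocessing/loss passes actually touched. A simplified sketch with made-up key sets; the real method also honors data_col dependencies and the model's own inference view requirements:

    # Keys recorded by the UsageTrackingDict wrappers during the dummy passes
    # (made-up values, for illustration only).
    train_accessed = {"obs", "actions", "rewards", "advantages"}
    postproc_accessed = {"obs", "rewards", "dones"}
    postproc_added = {"advantages"}

    all_accessed_keys = train_accessed | postproc_accessed | postproc_added

    view_requirements = {"obs", "new_obs", "actions", "rewards", "dones",
                         "infos", "eps_id", "agent_index", "unroll_id", "t"}
    # Always kept so the sampler can still build episodes correctly.
    never_drop = {"eps_id", "agent_index", "unroll_id", "dones"}

    kept = {k for k in view_requirements if k in all_accessed_keys | never_drop}
    print(sorted(kept))  # "infos", "new_obs", "t" get dropped in this toy case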
@@ -63,7 +63,7 @@ def build_tf_policy(
         ], Tuple[TensorType, type, List[TensorType]]]] = None,
         mixins: Optional[List[type]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
-        obs_include_prev_action_reward: bool = True) -> Type[TFPolicy]:
+        obs_include_prev_action_reward: bool = True) -> Type[DynamicTFPolicy]:
     """Helper function for creating a dynamic tf policy at runtime.

     Functions will be run in this order to initialize the policy:
@@ -591,6 +591,12 @@ class TorchPolicy(Policy):
             functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch

+    def _lazy_numpy_dict(self, postprocessed_batch):
+        train_batch = UsageTrackingDict(postprocessed_batch)
+        train_batch.set_get_interceptor(
+            functools.partial(convert_to_non_torch_type))
+        return train_batch
+

 # TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
 # and for all possible hyperparams, not just lr.
@@ -48,8 +48,14 @@ def build_torch_policy(
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
         before_init: Optional[Callable[
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
+        before_loss_init: Optional[Callable[[
+            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
+        ], None]] = None,
         after_init: Optional[Callable[
             [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
+        _after_loss_init: Optional[Callable[[
+            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
+        ], None]] = None,
         action_sampler_fn: Optional[Callable[[TensorType, List[
             TensorType]], Tuple[TensorType, TensorType]]] = None,
         action_distribution_fn: Optional[Callable[[
|
@ -64,7 +70,7 @@ def build_torch_policy(
|
||||||
apply_gradients_fn: Optional[Callable[
|
apply_gradients_fn: Optional[Callable[
|
||||||
[Policy, "torch.optim.Optimizer"], None]] = None,
|
[Policy, "torch.optim.Optimizer"], None]] = None,
|
||||||
mixins: Optional[List[type]] = None,
|
mixins: Optional[List[type]] = None,
|
||||||
view_requirements_fn: Optional[Callable[[], Dict[
|
view_requirements_fn: Optional[Callable[[Policy], Dict[
|
||||||
str, ViewRequirement]]] = None,
|
str, ViewRequirement]]] = None,
|
||||||
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
|
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
|
||||||
) -> Type[TorchPolicy]:
|
) -> Type[TorchPolicy]:
|
||||||
|
@@ -117,10 +123,17 @@ def build_torch_policy(
             TrainerConfigDict], None]]): Optional callable to run at the
             beginning of `Policy.__init__` that takes the same arguments as
             the Policy constructor. If None, this step will be skipped.
+        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
+            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
+            run prior to loss init. If None, this step will be skipped.
         after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
-            TrainerConfigDict], None]]): Optional callable to run at the end of
-            policy init that takes the same arguments as the policy
-            constructor. If None, this step will be skipped.
+            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
+            instead.
+        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
+            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
+            run after the loss init. If None, this step will be skipped.
+            This will be deprecated at some point and renamed into `after_init`
+            to match `build_tf_policy()` behavior.
         action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
             Tuple[TensorType, TensorType]]]): Optional callable returning a
             sampled action and its log-likelihood given some (obs and state)
@@ -128,13 +141,13 @@ def build_torch_policy(
             compute actions by calling self.model, then sampling from the
             so parameterized action distribution.
         action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
-            TensorType, TensorType], Tuple[TensorType, type,
-            List[TensorType]]]]): A callable that takes
-            the Policy, Model, the observation batch, an explore-flag, a
-            timestep, and an is_training flag and returns a tuple of
-            a) distribution inputs (parameters), b) a dist-class to generate
-            an action distribution object from, and c) internal-state outputs
-            (empty list if not applicable). If None, will either use
+            TensorType, TensorType], Tuple[TensorType,
+            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
+            that takes the Policy, Model, the observation batch, an
+            explore-flag, a timestep, and an is_training flag and returns a
+            tuple of a) distribution inputs (parameters), b) a dist-class to
+            generate an action distribution object from, and c) internal-state
+            outputs (empty list if not applicable). If None, will either use
             `action_sampler_fn` or compute actions by calling self.model,
             then sampling from the parameterized action distribution.
         make_model (Optional[Callable[[Policy, gym.spaces.Space,
@@ -67,14 +67,15 @@ if __name__ == "__main__":
     assert len(experiments) == 1,\
         "Error, can only run a single experiment per yaml file!"

-    print("== Test config ==")
-    print(yaml.dump(experiments))
-
     # Add torch option to exp configs.
     for exp in experiments.values():
         if args.torch:
             exp["config"]["framework"] = "torch"

+    # Print out the actual config.
+    print("== Test config ==")
+    print(yaml.dump(experiments))
+
     # Try running each test 3 times and make sure it reaches the given
     # reward.
     passed = False
@@ -35,10 +35,12 @@ def make_sample_batch(i):

 class AgentIOTest(unittest.TestCase):
     def setUp(self):
+        ray.init(num_cpus=1, ignore_reinit_error=True)
         self.test_dir = tempfile.mkdtemp()

     def tearDown(self):
         shutil.rmtree(self.test_dir)
+        ray.shutdown()

     def writeOutputs(self, output, fw):
         agent = PGTrainer(
@@ -225,7 +227,7 @@ class AgentIOTest(unittest.TestCase):

 class JsonIOTest(unittest.TestCase):
     def setUp(self):
-        ray.init(num_cpus=1)
+        ray.init(num_cpus=1, ignore_reinit_error=True)
         self.test_dir = tempfile.mkdtemp()

     def tearDown(self):
@@ -1,4 +1,4 @@
-cartpole-dqn:
+cartpole-simpleq:
     env: CartPole-v0
     run: SimpleQ
     stop:
@@ -1,14 +1,61 @@
+import gym
+import numpy as np
+import tree
+
 from ray.rllib.utils.framework import try_import_tf

 tf1, tf, tfv = try_import_tf()


+def convert_to_non_tf_type(stats):
+    """Converts values in `stats` to non-Tensor numpy or python types.
+
+    Args:
+        stats (any): Any (possibly nested) struct, the values in which will be
+            converted and returned as a new struct with all tf (eager) tensors
+            being converted to numpy types.
+
+    Returns:
+        Any: A new struct with the same structure as `stats`, but with all
+            values converted to non-tf Tensor types.
+    """
+
+    # The mapping function used to numpyize torch Tensors.
+    def mapping(item):
+        if isinstance(item, (tf.Tensor, tf.Variable)):
+            return item.numpy()
+        else:
+            return item
+
+    return tree.map_structure(mapping, stats)
+
+
 def explained_variance(y, pred):
     _, y_var = tf.nn.moments(y, axes=[0])
     _, diff_var = tf.nn.moments(y - pred, axes=[0])
     return tf.maximum(-1.0, 1 - (diff_var / y_var))


+def get_placeholder(*, space=None, value=None):
+    from ray.rllib.models.catalog import ModelCatalog
+
+    if space is not None:
+        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
+            return ModelCatalog.get_action_placeholder(space, None)
+        return tf1.placeholder(
+            shape=(None, ) + space.shape,
+            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
+        )
+    else:
+        assert value is not None
+        shape = value.shape[1:]
+        return tf1.placeholder(
+            shape=(None, ) + (shape if isinstance(shape, tuple) else tuple(
+                shape.as_list())),
+            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
+        )
+
+
 def huber_loss(x, delta=1.0):
     """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
     return tf.where(
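A possible usage sketch for the new `convert_to_non_tf_type()` helper (assumes TF2 running in eager mode and the dm-tree package RLlib already depends on; the `stats` dict below is made up):

    from ray.rllib.utils.framework import try_import_tf
    from ray.rllib.utils.tf_ops import convert_to_non_tf_type

    tf1, tf, tfv = try_import_tf()

    stats = {"policy_loss": tf.constant(0.25), "vf": [tf.constant(1.5)]}
    # Nested eager tensors come back as plain numpy scalars/arrays.
    print(convert_to_non_tf_type(stats))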
@@ -125,11 +125,11 @@ def minimize_and_clip(optimizer, clip_val=10):

 def one_hot(x, space):
     if isinstance(space, Discrete):
-        return nn.functional.one_hot(x, space.n)
+        return nn.functional.one_hot(x.long(), space.n)
     elif isinstance(space, MultiDiscrete):
         return torch.cat(
             [
-                nn.functional.one_hot(x[:, i], n)
+                nn.functional.one_hot(x[:, i].long(), n)
                 for i, n in enumerate(space.nvec)
             ],
             dim=-1)
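A minimal repro of why the `.long()` casts above are needed: `torch.nn.functional.one_hot` only accepts integer (int64) index tensors, while batched actions often arrive as floats:

    import torch
    from torch import nn

    x = torch.tensor([0.0, 2.0, 1.0])   # float action indices, e.g. from a buffer
    # nn.functional.one_hot(x, 3) would raise a RuntimeError here.
    print(nn.functional.one_hot(x.long(), 3))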
@@ -11,6 +11,7 @@ class UsageTrackingDict(dict):
     def __init__(self, *args, **kwargs):
         dict.__init__(self, *args, **kwargs)
         self.accessed_keys = set()
+        self.added_keys = set()
         self.intercepted_values = {}
         self.get_interceptor = None

@@ -32,6 +33,8 @@ class UsageTrackingDict(dict):
         return value

     def __setitem__(self, key, value):
+        if key not in self:
+            self.added_keys.add(key)
         dict.__setitem__(self, key, value)
         if key in self.intercepted_values:
             self.intercepted_values[key] = value
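The new `added_keys` set complements `accessed_keys`; together they let `_initialize_loss_from_dummy_batch()` distinguish columns that postprocessing merely reads from columns it creates. A small usage sketch (assuming the `UsageTrackingDict` behavior shown above, where reads are tracked by the existing get path):

    from ray.rllib.utils.tracking_dict import UsageTrackingDict

    batch = UsageTrackingDict({"obs": [1, 2, 3]})
    batch["advantages"] = [0.5, 0.5, 0.5]   # new key -> recorded in added_keys
    _ = batch["obs"]                        # read    -> recorded in accessed_keys

    print(batch.added_keys)     # {'advantages'}
    print(batch.accessed_keys)  # {'obs'}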