Revert "[RLlib] Fix bug in policy.py: normalize_actions=True has to call unsquash_action, not normalize_action." (#17002)

This reverts commit 7862dd64ea.
Amog Kamsetty, 2021-07-12 11:09:14 -07:00 (committed by GitHub)
parent 24e00fcb1b
commit bc33dc7e96
11 changed files with 42 additions and 125 deletions

View file

@@ -106,21 +106,10 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
"compress_observations": False,
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
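A minimal sketch of the ratio arithmetic described in the removed comment above. The helper name and defaults are hypothetical, not RLlib API (RLlib's own implementation is `calculate_rr_weights`, referenced in the removed comment); only the formula comes from the comment text.

def replay_to_rollout_ratio(training_intensity, train_batch_size,
                            rollout_fragment_length, num_workers=1,
                            num_envs_per_worker=1):
    # "Natural" value: timesteps trained on per batch divided by timesteps
    # collected per sampling round.
    natural = train_batch_size / (rollout_fragment_length
                                  * max(num_workers, 1) * num_envs_per_worker)
    if training_intensity is None:
        return 1.0  # replay proceeds at the native ratio
    return training_intensity / natural

# Numbers from the removed comment: 1000.0 / (250 / (1 * 1 * 1)) = 4.0,
# i.e. replay+train is weighted 4x relative to rollout+insert.
print(replay_to_rollout_ratio(1000.0, 250, 1))  # 4.0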

View file

@@ -103,21 +103,10 @@ DEFAULT_CONFIG = with_common_config({
"compress_observations": False,
# Callback to run before learning on a multi-agent batch of experiences.
"before_learn_on_batch": None,
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=800.0
# train_batch_size=32 rollout_fragment_length=4
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 32 / 4 = 8.0
# -> will make sure that replay+train op will be executed 100x as
# often as rollout+insert op (100 * 8.0 = 800.0).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
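The DQN-flavored numbers quoted above work out the same way: natural value = 32 / (4 * 1 * 1) = 8.0, so training_intensity=800.0 requests a 100x replay/train weighting. A hedged config sketch; the keys are taken from the hunk, any surrounding trainer setup is assumed.

config = {
    "train_batch_size": 32,
    "rollout_fragment_length": 4,
    "num_workers": 1,
    "num_envs_per_worker": 1,
    # natural value = 32 / (4 * 1 * 1) = 8.0
    # requested ratio = 800.0 / 8.0 = 100.0 (replay+train vs. rollout+insert)
    "training_intensity": 800.0,
}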

View file

@@ -102,21 +102,10 @@ DEFAULT_CONFIG = with_common_config({
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===

View file

@@ -943,7 +943,7 @@ class Trainer(Trainable):
policy_id: PolicyID = DEFAULT_POLICY_ID,
full_fetch: bool = False,
explore: bool = None,
unsquash_actions: Optional[bool] = None,
normalize_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
) -> TensorStructType:
"""Computes an action for the specified policy on the local Worker.
@@ -968,7 +968,7 @@ class Trainer(Trainable):
This is always set to True if RNN state is specified.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
unsquash_actions (bool): Should actions be unsquashed according to
normalize_actions (bool): Should actions be normalized according to
the env's/Policy's action space?
clip_actions (bool): Should actions be clipped according to the
env's/Policy's action space?
@@ -993,7 +993,7 @@ class Trainer(Trainable):
prev_action,
prev_reward,
info,
unsquash_actions=unsquash_actions,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
explore=explore)
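A minimal usage sketch for the `compute_action` signature above; the algorithm, environment, and config values are placeholders, not part of this diff. After this revert the keyword is `normalize_actions`; the reverted PR #17002 had renamed it to `unsquash_actions`.

import ray
import gym
from ray.rllib.agents.ppo import PPOTrainer  # placeholder continuous-control algo

ray.init()
trainer = PPOTrainer(env="Pendulum-v0",
                     config={"framework": "tf", "num_workers": 0})
obs = gym.make("Pendulum-v0").reset()

# Keyword restored by this revert:
action = trainer.compute_action(obs, explore=False, normalize_actions=True)

# Keyword from the reverted change (pre-revert API):
# action = trainer.compute_action(obs, explore=False, unsquash_actions=True)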

View file

@@ -69,9 +69,7 @@ class OffPolicyEstimator:
obs_batch=batch[SampleBatch.CUR_OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=True,
)
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS))
log_likelihoods = convert_to_numpy(log_likelihoods)
return np.exp(log_likelihoods)
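A small numeric sketch of the conversion at the end of the snippet above; the values are made up, and in the estimator the log-likelihoods come from `policy.compute_log_likelihoods`.

import numpy as np

log_likelihoods = np.array([-0.105, -2.303, -0.693])  # per-timestep logp values
action_probs = np.exp(log_likelihoods)                 # ~[0.90, 0.10, 0.50]
# The off-policy estimators then treat these as behavior-policy propensities,
# e.g. importance weight_t = target_policy_prob_t / action_probs[t].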

View file

@@ -17,7 +17,6 @@ from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import TensorType
@@ -566,15 +565,12 @@ def build_eager_tf_policy(
@with_lock
@override(Policy)
def compute_log_likelihoods(
self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
actions_normalized=True,
):
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
if action_sampler_fn and action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
"`action_distribution_fn` and a provided "
@@ -610,11 +606,6 @@ def build_eager_tf_policy(
dist_class = self.dist_class
action_dist = dist_class(dist_inputs, self.model)
# Normalize actions if necessary.
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
log_likelihoods = action_dist.logp(actions)
return log_likelihoods
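For intuition on the final `action_dist.logp(actions)` call above, a framework-free sketch of a diagonal-Gaussian log-likelihood (pure NumPy; the mean/log-std values are illustrative, and the actions are assumed to already be in the distribution's own, i.e. normalized, space):

import numpy as np

mean, log_std = np.array([0.2]), np.array([-0.5])   # model outputs (illustrative)
actions = np.array([0.35])                          # already normalized

std = np.exp(log_std)
log_likelihood = np.sum(
    -0.5 * ((actions - mean) / std) ** 2 - log_std - 0.5 * np.log(2.0 * np.pi))
print(log_likelihood)  # a single scalar logp for this action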

View file

@@ -14,7 +14,7 @@ from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import clip_action, \
get_base_struct_from_space, unbatch, unsquash_action
get_base_struct_from_space, normalize_action, unbatch
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict, Tuple, Union
@@ -161,7 +161,7 @@ class Policy(metaclass=ABCMeta):
clip_actions: bool = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
unsquash_actions: bool = None,
normalize_actions: bool = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Unbatched version of compute_actions.
@@ -176,7 +176,7 @@ class Policy(metaclass=ABCMeta):
episode (Optional[MultiAgentEpisode]): this provides access to all
of the internal episode state, which may be useful for
model-based or multi-agent algorithms.
unsquash_actions (bool): Should actions be unsquashed according to
normalize_actions (bool): Should actions be normalized according to
the Policy's action space?
clip_actions (bool): Should actions be clipped according to the
Policy's action space?
@@ -195,10 +195,8 @@ class Policy(metaclass=ABCMeta):
if any.
- info (dict): Dictionary of extra features, if any.
"""
# If policy works in normalized space, we should unsquash the action.
# Use value of config.normalize_actions, if None.
unsquash_actions = \
unsquash_actions if unsquash_actions is not None \
normalize_actions = \
normalize_actions if normalize_actions is not None \
else self.config["normalize_actions"]
clip_actions = clip_actions if clip_actions is not None else \
self.config["clip_actions"]
@@ -246,12 +244,9 @@ class Policy(metaclass=ABCMeta):
assert len(single_action) == 1
single_action = single_action[0]
# If we work in normalized action space (normalize_actions=True),
# we re-translate here into the env's action space.
if unsquash_actions:
single_action = unsquash_action(single_action,
self.action_space_struct)
# Clip, according to env's action space.
if normalize_actions:
single_action = normalize_action(single_action,
self.action_space_struct)
elif clip_actions:
single_action = clip_action(single_action,
self.action_space_struct)
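To make the post-processing branch above concrete, a hedged NumPy sketch of what the three mappings do to a single policy output on a 1-D Box; the bounds and values are illustrative, and the real helpers live in ray.rllib.utils.spaces.space_utils.

import numpy as np

low, high = -2.0, 2.0      # Box bounds, e.g. Pendulum's action space
policy_output = 0.5        # sampled in the normalized [-1, 1] range

# unsquash_action-style mapping: [-1, 1] -> [low, high]
unsquashed = low + (policy_output + 1.0) * (high - low) / 2.0    # 1.0
# normalize_action-style mapping: [low, high] -> [-1, 1]
normalized = (policy_output - low) / (high - low) * 2.0 - 1.0    # 0.25
# clip_action-style mapping: clamp to the bounds
clipped = float(np.clip(policy_output, low, high))               # 0.5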
@@ -319,10 +314,8 @@ class Policy(metaclass=ABCMeta):
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
"""Computes the log-prob/likelihood for a given action and observation.
Args:
@@ -337,10 +330,6 @@ class Policy(metaclass=ABCMeta):
Batch of previous action values.
prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
Batch of previous rewards.
actions_normalized (bool): Is the given `actions` already
normalized (between -1.0 and 1.0) or not? If not and
`normalize_actions=True`, we need to normalize the given
actions first, before calculating log likelihoods.
Returns:
TensorType: Batch of log probs/likelihoods, with shape:

View file

@@ -54,11 +54,7 @@ def do_test_log_likelihood(run,
obs_batch[0],
prev_action=prev_a,
prev_reward=prev_r,
explore=True,
# Do not unsquash actions
# (remain in normalized [-1.0; 1.0] space).
unsquash_actions=False,
))
explore=True))
# Test all taken actions for their log-likelihoods vs expected values.
if continuous:
@@ -93,9 +89,7 @@ def do_test_log_likelihood(run,
np.array([a]),
preprocessed_obs_batch,
prev_action_batch=np.array([prev_a]) if prev_a else None,
prev_reward_batch=np.array([prev_r]) if prev_r else None,
actions_normalized=True,
)
prev_reward_batch=np.array([prev_r]) if prev_r else None)
check(logp, expected_logp[0], rtol=0.2)
# Test all available actions for their logp values.
else:

View file

@@ -17,7 +17,6 @@ from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import ModelGradients, TensorType, \
TrainerConfigDict
@@ -400,10 +399,8 @@ class TFPolicy(Policy):
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
if self._log_likelihood is None:
raise ValueError("Cannot compute log-prob/likelihood w/o a "
@@ -414,11 +411,6 @@ class TFPolicy(Policy):
explore=False, tf_sess=self.get_session())
builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
# Normalize actions if necessary.
if actions_normalized is False and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
# Feed actions (for which we want logp values) into graph.
builder.add_feed_dict({self._action_input: actions})
# Feed observations.

View file

@@ -22,7 +22,6 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
@@ -374,10 +373,8 @@ class TorchPolicy(Policy):
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
if self.action_sampler_fn and self.action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
@@ -439,13 +436,7 @@ class TorchPolicy(Policy):
seq_lens)
action_dist = dist_class(dist_inputs, self.model)
# Normalize actions if necessary.
actions = input_dict[SampleBatch.ACTIONS]
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
log_likelihoods = action_dist.logp(actions)
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
return log_likelihoods

View file

@@ -156,12 +156,9 @@ def clip_action(action, action_space):
def unsquash_action(action, action_space_struct):
"""Unsquashes all components in `action` according to the given Space.
Inverse of `normalize_action()`. Useful for mapping policy action
outputs (normalized between -1.0 and 1.0) to an env's action space.
Unsquashing results in cont. action component values between the
given Space's bounds (`low` and `high`). This only applies to Box
components within the action space, whose dtype is float32 or float64.
components within the action space.
Args:
action (Any): The action to be unsquashed. This could be any complex
@@ -192,10 +189,8 @@ def unsquash_action(action, action_space_struct):
def normalize_action(action, action_space_struct):
"""Normalizes all (Box) components in `action` to be in [-1.0, 1.0].
Inverse of `unsquash_action()`. Useful for mapping an env's action
(arbitrary bounded values) to a [-1.0, 1.0] interval.
This only applies to Box components within the action space, whose
dtype is float32 or float64.
This only applies to Box components, whose dtype is float32 or float64,
within the action space.
Args:
action (Any): The action to be normalized. This could be any complex
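A short round-trip sketch of the two helpers whose docstrings are edited above, assuming an RLlib version where both are available (as in this diff); the Box bounds are illustrative.

import numpy as np
from gym.spaces import Box
from ray.rllib.utils.spaces.space_utils import (
    get_base_struct_from_space, normalize_action, unsquash_action)

space = Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)
struct = get_base_struct_from_space(space)

squashed = np.array([0.5], dtype=np.float32)      # policy-space value in [-1, 1]
env_action = unsquash_action(squashed, struct)    # -> ~[1.0], within [low, high]
recovered = normalize_action(env_action, struct)  # -> ~[0.5] again (inverse map)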