[RLlib] Fix bug in policy.py: normalize_actions=True has to call unsquash_action, not normalize_action. (#16774)

This commit is contained in:
Sven Mika 2021-07-08 17:31:34 +02:00 committed by GitHub
parent 9f6a92163b
commit 7862dd64ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 125 additions and 42 deletions

View file

@ -106,10 +106,21 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,
# === Optimization ===

View file

@ -103,10 +103,21 @@ DEFAULT_CONFIG = with_common_config({
"compress_observations": False,
# Callback to run before learning on a multi-agent batch of experiences.
"before_learn_on_batch": None,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=800.0
# train_batch_size=32 rollout_fragment_length=4
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 32 / 4 = 8.0
# -> will make sure that replay+train op will be executed 100x as
# often as rollout+insert op (100 * 8.0 = 800.0).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,
# === Optimization ===

View file

@ -102,10 +102,21 @@ DEFAULT_CONFIG = with_common_config({
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
# The intensity with which to update the model (vs collecting samples from
# the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into and
# sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
"training_intensity": None,
# === Optimization ===

View file

@ -939,7 +939,7 @@ class Trainer(Trainable):
policy_id: PolicyID = DEFAULT_POLICY_ID,
full_fetch: bool = False,
explore: bool = None,
normalize_actions: Optional[bool] = None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
) -> TensorStructType:
"""Computes an action for the specified policy on the local Worker.
@ -964,7 +964,7 @@ class Trainer(Trainable):
This is always set to True if RNN state is specified.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
normalize_actions (bool): Should actions be normalized according to
unsquash_actions (bool): Should actions be unsquashed according to
the env's/Policy's action space?
clip_actions (bool): Should actions be clipped according to the
env's/Policy's action space?
@ -989,7 +989,7 @@ class Trainer(Trainable):
prev_action,
prev_reward,
info,
normalize_actions=normalize_actions,
unsquash_actions=unsquash_actions,
clip_actions=clip_actions,
explore=explore)

View file

@ -69,7 +69,9 @@ class OffPolicyEstimator:
obs_batch=batch[SampleBatch.CUR_OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS))
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=True,
)
log_likelihoods = convert_to_numpy(log_likelihoods)
return np.exp(log_likelihoods)

View file

@ -17,6 +17,7 @@ from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import TensorType
@ -565,12 +566,15 @@ def build_eager_tf_policy(
@with_lock
@override(Policy)
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
def compute_log_likelihoods(
self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
actions_normalized=True,
):
if action_sampler_fn and action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
"`action_distribution_fn` and a provided "
@ -606,6 +610,11 @@ def build_eager_tf_policy(
dist_class = self.dist_class
action_dist = dist_class(dist_inputs, self.model)
# Normalize actions if necessary.
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
log_likelihoods = action_dist.logp(actions)
return log_likelihoods

View file

@ -14,7 +14,7 @@ from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import clip_action, \
get_base_struct_from_space, normalize_action, unbatch
get_base_struct_from_space, unbatch, unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
TensorType, TrainerConfigDict, Tuple, Union
@ -161,7 +161,7 @@ class Policy(metaclass=ABCMeta):
clip_actions: bool = None,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
normalize_actions: bool = None,
unsquash_actions: bool = None,
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Unbatched version of compute_actions.
@ -176,7 +176,7 @@ class Policy(metaclass=ABCMeta):
episode (Optional[MultiAgentEpisode]): this provides access to all
of the internal episode state, which may be useful for
model-based or multi-agent algorithms.
normalize_actions (bool): Should actions be normalized according to
unsquash_actions (bool): Should actions be unsquashed according to
the Policy's action space?
clip_actions (bool): Should actions be clipped according to the
Policy's action space?
@ -195,8 +195,10 @@ class Policy(metaclass=ABCMeta):
if any.
- info (dict): Dictionary of extra features, if any.
"""
normalize_actions = \
normalize_actions if normalize_actions is not None \
# If policy works in normalized space, we should unsquash the action.
# Use value of config.normalize_actions, if None.
unsquash_actions = \
unsquash_actions if unsquash_actions is not None \
else self.config["normalize_actions"]
clip_actions = clip_actions if clip_actions is not None else \
self.config["clip_actions"]
@ -244,9 +246,12 @@ class Policy(metaclass=ABCMeta):
assert len(single_action) == 1
single_action = single_action[0]
if normalize_actions:
single_action = normalize_action(single_action,
self.action_space_struct)
# If we work in normalized action space (normalize_actions=True),
# we re-translate here into the env's action space.
if unsquash_actions:
single_action = unsquash_action(single_action,
self.action_space_struct)
# Clip, according to env's action space.
elif clip_actions:
single_action = clip_action(single_action,
self.action_space_struct)
@ -314,8 +319,10 @@ class Policy(metaclass=ABCMeta):
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
"""Computes the log-prob/likelihood for a given action and observation.
Args:
@ -330,6 +337,10 @@ class Policy(metaclass=ABCMeta):
Batch of previous action values.
prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
Batch of previous rewards.
actions_normalized (bool): Is the given `actions` already
normalized (between -1.0 and 1.0) or not? If not and
`normalize_actions=True`, we need to normalize the given
actions first, before calculating log likelihoods.
Returns:
TensorType: Batch of log probs/likelihoods, with shape:

View file

@ -54,7 +54,11 @@ def do_test_log_likelihood(run,
obs_batch[0],
prev_action=prev_a,
prev_reward=prev_r,
explore=True))
explore=True,
# Do not unsquash actions
# (remain in normalized [-1.0; 1.0] space).
unsquash_actions=False,
))
# Test all taken actions for their log-likelihoods vs expected values.
if continuous:
@ -89,7 +93,9 @@ def do_test_log_likelihood(run,
np.array([a]),
preprocessed_obs_batch,
prev_action_batch=np.array([prev_a]) if prev_a else None,
prev_reward_batch=np.array([prev_r]) if prev_r else None)
prev_reward_batch=np.array([prev_r]) if prev_r else None,
actions_normalized=True,
)
check(logp, expected_logp[0], rtol=0.2)
# Test all available actions for their logp values.
else:

View file

@ -17,6 +17,7 @@ from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import ModelGradients, TensorType, \
TrainerConfigDict
@ -399,8 +400,10 @@ class TFPolicy(Policy):
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
if self._log_likelihood is None:
raise ValueError("Cannot compute log-prob/likelihood w/o a "
@ -411,6 +414,11 @@ class TFPolicy(Policy):
explore=False, tf_sess=self.get_session())
builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
# Normalize actions if necessary.
if actions_normalized is False and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
# Feed actions (for which we want logp values) into graph.
builder.add_feed_dict({self._action_input: actions})
# Feed observations.

View file

@ -22,6 +22,7 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
@ -373,8 +374,10 @@ class TorchPolicy(Policy):
state_batches: Optional[List[TensorType]] = None,
prev_action_batch: Optional[Union[List[TensorType],
TensorType]] = None,
prev_reward_batch: Optional[Union[List[
TensorType], TensorType]] = None) -> TensorType:
prev_reward_batch: Optional[Union[List[TensorType],
TensorType]] = None,
actions_normalized: bool = True,
) -> TensorType:
if self.action_sampler_fn and self.action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
@ -436,7 +439,13 @@ class TorchPolicy(Policy):
seq_lens)
action_dist = dist_class(dist_inputs, self.model)
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
# Normalize actions if necessary.
actions = input_dict[SampleBatch.ACTIONS]
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
log_likelihoods = action_dist.logp(actions)
return log_likelihoods

View file

@ -156,9 +156,12 @@ def clip_action(action, action_space):
def unsquash_action(action, action_space_struct):
"""Unsquashes all components in `action` according to the given Space.
Inverse of `normalize_action()`. Useful for mapping policy action
outputs (normalized between -1.0 and 1.0) to an env's action space.
Unsquashing results in cont. action component values between the
given Space's bounds (`low` and `high`). This only applies to Box
components within the action space.
components within the action space, whose dtype is float32 or float64.
Args:
action (Any): The action to be unsquashed. This could be any complex
@ -189,8 +192,10 @@ def unsquash_action(action, action_space_struct):
def normalize_action(action, action_space_struct):
"""Normalizes all (Box) components in `action` to be in [-1.0, 1.0].
This only applies to Box components, whose dtype is float32 or float64,
within the action space.
Inverse of `unsquash_action()`. Useful for mapping an env's action
(arbitrary bounded values) to a [-1.0, 1.0] interval.
This only applies to Box components within the action space, whose
dtype is float32 or float64.
Args:
action (Any): The action to be normalized. This could be any complex