[RLlib] Fix bug in policy.py: normalize_actions=True has to call unsquash_action, not normalize_action. (#16774)

parent 9f6a92163b
commit 7862dd64ea

11 changed files with 125 additions and 42 deletions
@@ -106,10 +106,21 @@ DEFAULT_CONFIG = with_common_config({
    "prioritized_replay_eps": 1e-6,
    # Whether to LZ4 compress observations
    "compress_observations": False,
    # If set, this will fix the ratio of replayed from a buffer and learned on
    # timesteps to sampled from an environment and stored in the replay buffer
    # timesteps. Otherwise, the replay will proceed at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).

    # The intensity with which to update the model (vs collecting samples from
    # the env). If None, uses the "natural" value of:
    # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
    # `num_envs_per_worker`).
    # If provided, will make sure that the ratio between ts inserted into and
    # sampled from the buffer matches the given value.
    # Example:
    # training_intensity=1000.0
    # train_batch_size=250 rollout_fragment_length=1
    # num_workers=1 (or 0) num_envs_per_worker=1
    # -> natural value = 250 / 1 = 250.0
    # -> will make sure that replay+train op will be executed 4x as
    #    often as rollout+insert op (4 * 250 = 1000).
    # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
    "training_intensity": None,

    # === Optimization ===
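To make the training_intensity arithmetic in the comment above concrete, here is a small standalone sketch in plain Python. It is not RLlib's calculate_rr_weights, just the same calculation spelled out with the numbers from the example:

def natural_training_intensity(train_batch_size, rollout_fragment_length,
                               num_workers, num_envs_per_worker):
    # Timesteps entering the buffer per sampling round; num_workers=0 means
    # sampling happens on the local worker, hence the max(..., 1).
    sampled_per_round = (rollout_fragment_length * max(num_workers, 1) *
                         num_envs_per_worker)
    # Each train step consumes `train_batch_size` timesteps from the buffer.
    return train_batch_size / sampled_per_round


natural = natural_training_intensity(250, 1, 1, 1)   # -> 250.0
training_intensity = 1000.0
# Replay+train runs this many times per rollout+insert op (4 * 250 = 1000):
replay_train_weight = training_intensity / natural   # -> 4.0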
@@ -103,10 +103,21 @@ DEFAULT_CONFIG = with_common_config({
    "compress_observations": False,
    # Callback to run before learning on a multi-agent batch of experiences.
    "before_learn_on_batch": None,
    # If set, this will fix the ratio of replayed from a buffer and learned on
    # timesteps to sampled from an environment and stored in the replay buffer
    # timesteps. Otherwise, the replay will proceed at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).

    # The intensity with which to update the model (vs collecting samples from
    # the env). If None, uses the "natural" value of:
    # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
    # `num_envs_per_worker`).
    # If provided, will make sure that the ratio between ts inserted into and
    # sampled from the buffer matches the given value.
    # Example:
    # training_intensity=800.0
    # train_batch_size=32 rollout_fragment_length=4
    # num_workers=1 (or 0) num_envs_per_worker=1
    # -> natural value = 32 / 4 = 8.0
    # -> will make sure that replay+train op will be executed 100x as
    #    often as rollout+insert op (100 * 8.0 = 800.0).
    # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
    "training_intensity": None,

    # === Optimization ===
@@ -102,10 +102,21 @@ DEFAULT_CONFIG = with_common_config({
    "final_prioritized_replay_beta": 0.4,
    # Whether to LZ4 compress observations
    "compress_observations": False,
    # If set, this will fix the ratio of replayed from a buffer and learned on
    # timesteps to sampled from an environment and stored in the replay buffer
    # timesteps. Otherwise, the replay will proceed at the native ratio
    # determined by (train_batch_size / rollout_fragment_length).

    # The intensity with which to update the model (vs collecting samples from
    # the env). If None, uses the "natural" value of:
    # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
    # `num_envs_per_worker`).
    # If provided, will make sure that the ratio between ts inserted into and
    # sampled from the buffer matches the given value.
    # Example:
    # training_intensity=1000.0
    # train_batch_size=250 rollout_fragment_length=1
    # num_workers=1 (or 0) num_envs_per_worker=1
    # -> natural value = 250 / 1 = 250.0
    # -> will make sure that replay+train op will be executed 4x as
    #    often as rollout+insert op (4 * 250 = 1000).
    # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
    "training_intensity": None,

    # === Optimization ===
@@ -939,7 +939,7 @@ class Trainer(Trainable):
            policy_id: PolicyID = DEFAULT_POLICY_ID,
            full_fetch: bool = False,
            explore: bool = None,
            normalize_actions: Optional[bool] = None,
            unsquash_actions: Optional[bool] = None,
            clip_actions: Optional[bool] = None,
    ) -> TensorStructType:
        """Computes an action for the specified policy on the local Worker.

@@ -964,7 +964,7 @@ class Trainer(Trainable):
                This is always set to True if RNN state is specified.
            explore (bool): Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            normalize_actions (bool): Should actions be normalized according to
            unsquash_actions (bool): Should actions be unsquashed according to
                the env's/Policy's action space?
            clip_actions (bool): Should actions be clipped according to the
                env's/Policy's action space?

@@ -989,7 +989,7 @@ class Trainer(Trainable):
            prev_action,
            prev_reward,
            info,
            normalize_actions=normalize_actions,
            unsquash_actions=unsquash_actions,
            clip_actions=clip_actions,
            explore=explore)
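For orientation, a hedged usage sketch of the renamed compute_action argument. The PPOTrainer, the Pendulum-v0 env, and the observation below are illustrative assumptions, not part of this diff:

import gym
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init(ignore_reinit_error=True)
trainer = PPOTrainer(config={"env": "Pendulum-v0", "normalize_actions": True})

obs = gym.make("Pendulum-v0").reset()
# With normalize_actions=True the policy acts in [-1.0, 1.0]; unsquash_actions
# translates the returned action back into the env's Box(-2.0, 2.0) bounds.
action = trainer.compute_action(obs, explore=False, unsquash_actions=True)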
@@ -69,7 +69,9 @@ class OffPolicyEstimator:
            obs_batch=batch[SampleBatch.CUR_OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS))
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=True,
        )
        log_likelihoods = convert_to_numpy(log_likelihoods)
        return np.exp(log_likelihoods)
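The exp(log-likelihood) values returned here are typically consumed as importance-sampling ratios in off-policy estimation. A minimal sketch of that downstream use, with placeholder numbers that are not taken from this diff:

import numpy as np

# Log-likelihoods of the logged actions under the *current* policy
# (as returned by compute_log_likelihoods above); placeholder values.
new_logp = np.array([-0.9, -1.2, -0.3])
new_action_prob = np.exp(new_logp)

# Action probabilities recorded by the behavior policy at sampling time.
behavior_action_prob = np.array([0.40, 0.35, 0.60])

# Per-step importance-sampling ratios used by the estimators.
is_ratios = new_action_prob / behavior_action_prob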
@@ -17,6 +17,7 @@ from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import TensorType

@@ -565,12 +566,15 @@ def build_eager_tf_policy(

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(self,
                                    actions,
                                    obs_batch,
                                    state_batches=None,
                                    prev_action_batch=None,
                                    prev_reward_batch=None):
        def compute_log_likelihoods(
                self,
                actions,
                obs_batch,
                state_batches=None,
                prev_action_batch=None,
                prev_reward_batch=None,
                actions_normalized=True,
        ):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError("Cannot compute log-prob/likelihood w/o an "
                                 "`action_distribution_fn` and a provided "

@@ -606,6 +610,11 @@ def build_eager_tf_policy(
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
@@ -14,7 +14,7 @@ from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import clip_action, \
    get_base_struct_from_space, normalize_action, unbatch
    get_base_struct_from_space, unbatch, unsquash_action
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
    TensorType, TrainerConfigDict, Tuple, Union

@@ -161,7 +161,7 @@ class Policy(metaclass=ABCMeta):
            clip_actions: bool = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            normalize_actions: bool = None,
            unsquash_actions: bool = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Unbatched version of compute_actions.

@@ -176,7 +176,7 @@ class Policy(metaclass=ABCMeta):
            episode (Optional[MultiAgentEpisode]): this provides access to all
                of the internal episode state, which may be useful for
                model-based or multi-agent algorithms.
            normalize_actions (bool): Should actions be normalized according to
            unsquash_actions (bool): Should actions be unsquashed according to
                the Policy's action space?
            clip_actions (bool): Should actions be clipped according to the
                Policy's action space?

@@ -195,8 +195,10 @@ class Policy(metaclass=ABCMeta):
                if any.
            - info (dict): Dictionary of extra features, if any.
        """
        normalize_actions = \
            normalize_actions if normalize_actions is not None \
        # If policy works in normalized space, we should unsquash the action.
        # Use value of config.normalize_actions, if None.
        unsquash_actions = \
            unsquash_actions if unsquash_actions is not None \
            else self.config["normalize_actions"]
        clip_actions = clip_actions if clip_actions is not None else \
            self.config["clip_actions"]

@@ -244,9 +246,12 @@ class Policy(metaclass=ABCMeta):
        assert len(single_action) == 1
        single_action = single_action[0]

        if normalize_actions:
            single_action = normalize_action(single_action,
                                             self.action_space_struct)
        # If we work in normalized action space (normalize_actions=True),
        # we re-translate here into the env's action space.
        if unsquash_actions:
            single_action = unsquash_action(single_action,
                                            self.action_space_struct)
        # Clip, according to env's action space.
        elif clip_actions:
            single_action = clip_action(single_action,
                                        self.action_space_struct)

@@ -314,8 +319,10 @@ class Policy(metaclass=ABCMeta):
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[
                TensorType], TensorType]] = None) -> TensorType:
            prev_reward_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:
        """Computes the log-prob/likelihood for a given action and observation.

        Args:

@@ -330,6 +337,10 @@ class Policy(metaclass=ABCMeta):
                Batch of previous action values.
            prev_reward_batch (Optional[Union[List[TensorType], TensorType]]):
                Batch of previous rewards.
            actions_normalized (bool): Is the given `actions` already
                normalized (between -1.0 and 1.0) or not? If not and
                `normalize_actions=True`, we need to normalize the given
                actions first, before calculating log likelihoods.

        Returns:
            TensorType: Batch of log probs/likelihoods, with shape:
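A hedged end-to-end sketch of the new actions_normalized flag on Policy.compute_log_likelihoods. The PPOTrainer/Pendulum-v0 setup, the action values, and the zero observations are assumptions for illustration only:

import numpy as np
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init(ignore_reinit_error=True)
trainer = PPOTrainer(config={"env": "Pendulum-v0", "normalize_actions": True})
policy = trainer.get_policy()

# Actions expressed in env coordinates (Pendulum torques in [-2.0, 2.0]).
env_space_actions = np.array([[1.5], [-0.7]], dtype=np.float32)
obs_batch = np.zeros((2, 3), dtype=np.float32)  # placeholder observations

# The policy's distribution lives in the normalized [-1.0, 1.0] space, so we
# pass actions_normalized=False and let the policy normalize the actions
# before evaluating their log-likelihoods.
logp = policy.compute_log_likelihoods(
    actions=env_space_actions,
    obs_batch=obs_batch,
    actions_normalized=False,
)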
@@ -54,7 +54,11 @@ def do_test_log_likelihood(run,
                obs_batch[0],
                prev_action=prev_a,
                prev_reward=prev_r,
                explore=True))
                explore=True,
                # Do not unsquash actions
                # (remain in normalized [-1.0; 1.0] space).
                unsquash_actions=False,
            ))

    # Test all taken actions for their log-likelihoods vs expected values.
    if continuous:

@@ -89,7 +93,9 @@ def do_test_log_likelihood(run,
                np.array([a]),
                preprocessed_obs_batch,
                prev_action_batch=np.array([prev_a]) if prev_a else None,
                prev_reward_batch=np.array([prev_r]) if prev_r else None)
                prev_reward_batch=np.array([prev_r]) if prev_r else None,
                actions_normalized=True,
            )
            check(logp, expected_logp[0], rtol=0.2)
        # Test all available actions for their logp values.
        else:
@@ -17,6 +17,7 @@ from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import ModelGradients, TensorType, \
    TrainerConfigDict

@@ -399,8 +400,10 @@ class TFPolicy(Policy):
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[
                TensorType], TensorType]] = None) -> TensorType:
            prev_reward_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:

        if self._log_likelihood is None:
            raise ValueError("Cannot compute log-prob/likelihood w/o a "

@@ -411,6 +414,11 @@ class TFPolicy(Policy):
            explore=False, tf_sess=self.get_session())

        builder = TFRunBuilder(self._sess, "compute_log_likelihoods")

        # Normalize actions if necessary.
        if actions_normalized is False and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
@@ -22,6 +22,7 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor

@@ -373,8 +374,10 @@ class TorchPolicy(Policy):
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[
                TensorType], TensorType]] = None) -> TensorType:
            prev_reward_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:

        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError("Cannot compute log-prob/likelihood w/o an "

@@ -436,7 +439,13 @@ class TorchPolicy(Policy):
                seq_lens)

        action_dist = dist_class(dist_inputs, self.model)
        log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

        # Normalize actions if necessary.
        actions = input_dict[SampleBatch.ACTIONS]
        if not actions_normalized and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        log_likelihoods = action_dist.logp(actions)

        return log_likelihoods
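To see why the action must be normalized before calling logp, here is a standalone torch illustration (not RLlib code); the distribution parameters and the Box bounds are made up for the example:

import torch
from torch.distributions import Normal

# A policy distribution that lives in the squashed [-1.0, 1.0] action space.
dist = Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.5]))

env_action = torch.tensor([1.6])   # env space, e.g. a Box(-2.0, 2.0) component
normalized = env_action / 2.0      # [-2, 2] -> [-1, 1] for this symmetric Box

wrong_logp = dist.log_prob(env_action)   # evaluated far out in the tail
right_logp = dist.log_prob(normalized)   # what compute_log_likelihoods needs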
@@ -156,9 +156,12 @@ def clip_action(action, action_space):
def unsquash_action(action, action_space_struct):
    """Unsquashes all components in `action` according to the given Space.

    Inverse of `normalize_action()`. Useful for mapping policy action
    outputs (normalized between -1.0 and 1.0) to an env's action space.
    Unsquashing results in cont. action component values between the
    given Space's bounds (`low` and `high`). This only applies to Box
    components within the action space.
    components within the action space, whose dtype is float32 or float64.

    Args:
        action (Any): The action to be unsquashed. This could be any complex

@@ -189,8 +192,10 @@ def unsquash_action(action, action_space_struct):
def normalize_action(action, action_space_struct):
    """Normalizes all (Box) components in `action` to be in [-1.0, 1.0].

    This only applies to Box components, whose dtype is float32 or float64,
    within the action space.
    Inverse of `unsquash_action()`. Useful for mapping an env's action
    (arbitrary bounded values) to a [-1.0, 1.0] interval.
    This only applies to Box components within the action space, whose
    dtype is float32 or float64.

    Args:
        action (Any): The action to be normalized. This could be any complex
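The mapping these two helpers describe, written out for a single float Box component as a standalone sketch. The low/high bounds are assumptions; RLlib's real implementations also walk nested (complex) spaces:

import numpy as np

low, high = -2.0, 2.0  # assumed Box bounds for illustration


def normalize(a):
    # Env space [low, high] -> normalized [-1.0, 1.0].
    return (a - low) / (high - low) * 2.0 - 1.0


def unsquash(a):
    # Normalized [-1.0, 1.0] -> env space [low, high]; inverse of normalize().
    return low + (a + 1.0) * (high - low) / 2.0


a = np.array([-2.0, 0.5, 2.0])
assert np.allclose(unsquash(normalize(a)), a)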