# ray/rllib/utils/exploration/parameter_noise.py

from gym.spaces import Box, Discrete
import numpy as np
from typing import Optional, TYPE_CHECKING, Union
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import get_variable, try_import_tf, \
try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import softmax, SMALL_NUMBER
from ray.rllib.utils.typing import TensorType

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class ParameterNoise(Exploration):
"""An exploration that changes a Model's parameters.
Implemented based on:
[1] https://blog.openai.com/better-exploration-with-parameter-noise/
[2] https://arxiv.org/pdf/1706.01905.pdf
At the beginning of an episode, Gaussian noise is added to all weights
of the model. At the end of the episode, the noise is undone and an action
diff (pi-delta) is calculated, from which we determine the changes in the
noise's stddev for the next episode.
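    Example:
        A minimal sketch of how this class is typically activated via a
        Trainer's `exploration_config` (the surrounding config keys depend
        on the algorithm used; the values shown are illustrative only):

        >>> config["exploration_config"] = {
        ...     "type": "ParameterNoise",
        ...     "initial_stddev": 1.0,
        ...     "random_timesteps": 10000,
        ... }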
"""
def __init__(self,
action_space,
*,
framework: str,
policy_config: dict,
model: ModelV2,
initial_stddev: float = 1.0,
random_timesteps: int = 10000,
sub_exploration: Optional[dict] = None,
**kwargs):
"""Initializes a ParameterNoise Exploration object.
Args:
initial_stddev (float): The initial stddev to use for the noise.
random_timesteps (int): The number of timesteps to act completely
randomly (see [1]).
sub_exploration (Optional[dict]): Optional sub-exploration config.
None for auto-detection/setup.
"""
assert framework is not None
super().__init__(
action_space,
policy_config=policy_config,
model=model,
framework=framework,
**kwargs)
self.stddev = get_variable(
initial_stddev, framework=self.framework, tf_name="stddev")
self.stddev_val = initial_stddev # Out-of-graph tf value holder.
        # The weight variables of the Model to which noise should be applied.
        # This excludes any variable whose name contains "LayerNorm" (layer
        # normalization layers, which should not be perturbed).
self.model_variables = [
v for k, v in self.model.trainable_variables(as_dict=True).items()
if "LayerNorm" not in k
]
        # Our noise to be added to the weights. Each item in `self.noise`
        # corresponds to one Model variable and holds the Gaussian noise to
        # be added to that variable (weight).
self.noise = []
for var in self.model_variables:
name_ = var.name.split(":")[0] + "_noisy" if var.name else ""
self.noise.append(
get_variable(
np.zeros(var.shape, dtype=np.float32),
framework=self.framework,
tf_name=name_,
torch_tensor=True,
device=self.device))
# tf-specific ops to sample, assign and remove noise.
if self.framework == "tf" and not tf.executing_eagerly():
self.tf_sample_new_noise_op = \
self._tf_sample_new_noise_op()
self.tf_add_stored_noise_op = \
self._tf_add_stored_noise_op()
self.tf_remove_noise_op = \
self._tf_remove_noise_op()
# Create convenience sample+add op for tf.
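            # Fetching the no_op built below forces, via control
            # dependencies, the ordering: first sample new noise, then add
            # it to the weights.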
with tf1.control_dependencies([self.tf_sample_new_noise_op]):
add_op = self._tf_add_stored_noise_op()
with tf1.control_dependencies([add_op]):
self.tf_sample_new_noise_and_add_op = tf.no_op()
# Whether the Model's weights currently have noise added or not.
self.weights_are_currently_noisy = False
# Auto-detection of underlying exploration functionality.
if sub_exploration is None:
# For discrete action spaces, use an underlying EpsilonGreedy with
# a special schedule.
if isinstance(self.action_space, Discrete):
sub_exploration = {
"type": "EpsilonGreedy",
"epsilon_schedule": {
"type": "PiecewiseSchedule",
# Step function (see [2]).
"endpoints": [(0, 1.0), (random_timesteps + 1, 1.0),
(random_timesteps + 2, 0.01)],
"outside_value": 0.01
}
}
elif isinstance(self.action_space, Box):
sub_exploration = {
"type": "OrnsteinUhlenbeckNoise",
"random_timesteps": random_timesteps,
}
# TODO(sven): Implement for any action space.
else:
raise NotImplementedError
self.sub_exploration = from_config(
Exploration,
sub_exploration,
framework=self.framework,
action_space=self.action_space,
policy_config=self.policy_config,
model=self.model,
**kwargs)
# Whether we need to call `self._delayed_on_episode_start` before
# the forward pass.
        self.episode_started = False

@override(Exploration)
def before_compute_actions(self,
*,
timestep: Optional[int] = None,
                               explore: Optional[bool] = None,
tf_sess: Optional["tf.Session"] = None):
explore = explore if explore is not None else \
self.policy_config["explore"]
        # Is this the first forward pass in the new episode? If so,
        # re-sample the noise and add it to the weights.
if self.episode_started:
self._delayed_on_episode_start(explore, tf_sess)
# Add noise if necessary.
if explore and not self.weights_are_currently_noisy:
self._add_stored_noise(tf_sess=tf_sess)
# Remove noise if necessary.
elif not explore and self.weights_are_currently_noisy:
            self._remove_noise(tf_sess=tf_sess)

@override(Exploration)
def get_exploration_action(self, *,
action_distribution: ActionDistribution,
timestep: Union[TensorType, int],
explore: Union[TensorType, bool]):
        # Use our sub-exploration object to handle the final exploration
        # action (depends on the algo-type/action-space/etc.).
return self.sub_exploration.get_exploration_action(
action_distribution=action_distribution,
timestep=timestep,
            explore=explore)

@override(Exploration)
def on_episode_start(self,
policy: "Policy",
*,
environment: BaseEnv = None,
episode: int = None,
tf_sess: Optional["tf.Session"] = None):
        # We have to delay the noise-adding step by one forward call.
        # This is because the optimizer takes its step right after the
        # episode was reset (at which point the noise would already have been
        # added). We don't want the update to be applied to a noisy net.
        self.episode_started = True

def _delayed_on_episode_start(self, explore, tf_sess):
# Sample fresh noise and add to weights.
if explore:
self._sample_new_noise_and_add(tf_sess=tf_sess, override=True)
# Only sample, don't apply anything to the weights.
else:
self._sample_new_noise(tf_sess=tf_sess)
        self.episode_started = False

@override(Exploration)
def on_episode_end(self,
policy,
*,
environment=None,
episode=None,
tf_sess=None):
# Remove stored noise from weights (only if currently noisy).
if self.weights_are_currently_noisy:
            self._remove_noise(tf_sess=tf_sess)

@override(Exploration)
def postprocess_trajectory(self,
policy: "Policy",
sample_batch: SampleBatch,
tf_sess: Optional["tf.Session"] = None):
noisy_action_dist = noise_free_action_dist = None
# Adjust the stddev depending on the action (pi)-distance.
# Also see [1] for details.
# TODO(sven): Find out whether this can be scrapped by simply using
# the `sample_batch` to get the noisy/noise-free action dist.
_, _, fetches = policy.compute_actions(
obs_batch=sample_batch[SampleBatch.CUR_OBS],
# TODO(sven): What about state-ins and seq-lens?
prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
explore=self.weights_are_currently_noisy)
# Categorical case (e.g. DQN).
if policy.dist_class in (Categorical, TorchCategorical):
action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
# Deterministic (Gaussian actions, e.g. DDPG).
elif policy.dist_class in [Deterministic, TorchDeterministic]:
action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]
else:
raise NotImplementedError # TODO(sven): Other action-dist cases.
if self.weights_are_currently_noisy:
noisy_action_dist = action_dist
else:
noise_free_action_dist = action_dist
_, _, fetches = policy.compute_actions(
obs_batch=sample_batch[SampleBatch.CUR_OBS],
prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
explore=not self.weights_are_currently_noisy)
# Categorical case (e.g. DQN).
if policy.dist_class in (Categorical, TorchCategorical):
action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
# Deterministic (Gaussian actions, e.g. DDPG).
elif policy.dist_class in [Deterministic, TorchDeterministic]:
action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]
if noisy_action_dist is None:
noisy_action_dist = action_dist
else:
noise_free_action_dist = action_dist
delta = distance = None
# Categorical case (e.g. DQN).
if policy.dist_class in (Categorical, TorchCategorical):
# Calculate KL-divergence (DKL(clean||noisy)) according to [2].
# TODO(sven): Allow KL-divergence to be calculated by our
# Distribution classes (don't support off-graph/numpy yet).
distance = np.nanmean(
np.sum(
noise_free_action_dist *
np.log(noise_free_action_dist /
(noisy_action_dist + SMALL_NUMBER)), 1))
current_epsilon = self.sub_exploration.get_info(
sess=tf_sess)["cur_epsilon"]
delta = -np.log(1 - current_epsilon +
current_epsilon / self.action_space.n)
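            # This delta equals the KL-divergence between a greedy policy
            # and its epsilon-greedy counterpart, -log(1 - eps + eps / n),
            # so the stddev gets tuned such that the parameter-noise induced
            # policy change matches the current epsilon-greedy exploration
            # level (see [2]).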
elif policy.dist_class in [Deterministic, TorchDeterministic]:
# Calculate MSE between noisy and non-noisy output (see [2]).
distance = np.sqrt(
np.mean(np.square(noise_free_action_dist - noisy_action_dist)))
current_scale = self.sub_exploration.get_info(
sess=tf_sess)["cur_scale"]
delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * \
current_scale
# Adjust stddev according to the calculated action-distance.
if distance <= delta:
self.stddev_val *= 1.01
else:
self.stddev_val /= 1.01
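        # Example: With delta=0.01, a measured distance of 0.008 (noisy and
        # noise-free policies still very close) grows the stddev by 1%,
        # whereas a distance of 0.02 (policies diverged too far) shrinks it
        # by ~1%.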
# Set self.stddev to calculated value.
if self.framework == "tf":
self.stddev.load(self.stddev_val, session=tf_sess)
else:
self.stddev = self.stddev_val
        return sample_batch

def _sample_new_noise(self, *, tf_sess=None):
"""Samples new noise and stores it in `self.noise`."""
if self.framework == "tf":
tf_sess.run(self.tf_sample_new_noise_op)
elif self.framework in ["tfe", "tf2"]:
self._tf_sample_new_noise_op()
else:
for i in range(len(self.noise)):
self.noise[i] = torch.normal(
mean=torch.zeros(self.noise[i].size()),
                    std=self.stddev).to(self.device)

def _tf_sample_new_noise_op(self):
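        """Generates a tf-op that re-samples all noise variables.

        Also used by tf-eager.

        Returns:
            tf.op: A grouped op that fills each tensor in `self.noise` with
                fresh Gaussian noise (stddev=self.stddev).
        """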
added_noises = []
for noise in self.noise:
added_noises.append(
tf1.assign(
noise,
tf.random.normal(
shape=noise.shape,
stddev=self.stddev,
dtype=tf.float32)))
        return tf.group(*added_noises)

def _sample_new_noise_and_add(self, *, tf_sess=None, override=False):
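        """Samples new noise and adds it to the weights.

        Args:
            tf_sess (Optional[tf.Session]): The tf-session to use.
            override (bool): If True, removes any currently applied noise
                first, so freshly sampled noise is never stacked on top of
                old noise.
        """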
if self.framework == "tf":
if override and self.weights_are_currently_noisy:
tf_sess.run(self.tf_remove_noise_op)
tf_sess.run(self.tf_sample_new_noise_and_add_op)
else:
if override and self.weights_are_currently_noisy:
self._remove_noise()
self._sample_new_noise()
self._add_stored_noise()
        self.weights_are_currently_noisy = True

def _add_stored_noise(self, *, tf_sess=None):
"""Adds the stored `self.noise` to the model's parameters.
Note: No new sampling of noise here.
Args:
tf_sess (Optional[tf.Session]): The tf-session to use to add the
stored noise to the (currently noise-free) weights.
override (bool): If True, undo any currently applied noise first,
then add the currently stored noise.
"""
# Make sure we only add noise to currently noise-free weights.
assert self.weights_are_currently_noisy is False
# Add stored noise to the model's parameters.
if self.framework == "tf":
tf_sess.run(self.tf_add_stored_noise_op)
elif self.framework in ["tf2", "tfe"]:
self._tf_add_stored_noise_op()
else:
for var, noise in zip(self.model_variables, self.noise):
# Add noise to weights in-place.
var.requires_grad = False
var.add_(noise)
var.requires_grad = True
        self.weights_are_currently_noisy = True

def _tf_add_stored_noise_op(self):
"""Generates tf-op that assigns the stored noise to weights.
Also used by tf-eager.
Returns:
tf.op: The tf op to apply the already stored noise to the NN.
"""
add_noise_ops = list()
for var, noise in zip(self.model_variables, self.noise):
add_noise_ops.append(tf1.assign_add(var, noise))
ret = tf.group(*tuple(add_noise_ops))
with tf1.control_dependencies([ret]):
            return tf.no_op()

def _remove_noise(self, *, tf_sess=None):
"""
Removes the current action noise from the model parameters.
Args:
tf_sess (Optional[tf.Session]): The tf-session to use to remove
the noise from the (currently noisy) weights.
"""
        # Make sure we only remove noise if the weights are currently noisy.
assert self.weights_are_currently_noisy is True
# Removes the stored noise from the model's parameters.
if self.framework == "tf":
tf_sess.run(self.tf_remove_noise_op)
elif self.framework in ["tf2", "tfe"]:
self._tf_remove_noise_op()
else:
for var, noise in zip(self.model_variables, self.noise):
# Remove noise from weights in-place.
var.requires_grad = False
var.add_(-noise)
var.requires_grad = True
        self.weights_are_currently_noisy = False

def _tf_remove_noise_op(self):
"""Generates a tf-op for removing noise from the model's weights.
Also used by tf-eager.
Returns:
tf.op: The tf op to remve the currently stored noise from the NN.
"""
remove_noise_ops = list()
for var, noise in zip(self.model_variables, self.noise):
remove_noise_ops.append(tf1.assign_add(var, -noise))
ret = tf.group(*tuple(remove_noise_ops))
with tf1.control_dependencies([ret]):
            return tf.no_op()

@override(Exploration)
def get_info(self, sess=None):
return {"cur_stddev": self.stddev_val}