2020-04-09 23:04:21 +02:00
|
|
|
from gym.spaces import Box, Discrete
|
2020-03-29 00:16:30 +01:00
|
|
|
import numpy as np
|
2020-10-07 21:59:14 +02:00
|
|
|
from typing import Optional, TYPE_CHECKING, Union
|
2020-03-29 00:16:30 +01:00
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
2020-10-07 21:59:14 +02:00
|
|
|
from ray.rllib.env.base_env import BaseEnv
|
|
|
|
from ray.rllib.models.action_dist import ActionDistribution
|
2020-03-29 00:16:30 +01:00
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
2020-04-09 23:04:21 +02:00
|
|
|
from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic
|
|
|
|
from ray.rllib.models.torch.torch_action_dist import (
|
|
|
|
TorchCategorical,
|
|
|
|
TorchDeterministic,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2020-10-07 21:59:14 +02:00
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2020-03-29 00:16:30 +01:00
|
|
|
from ray.rllib.utils.annotations import override
|
|
|
|
from ray.rllib.utils.exploration.exploration import Exploration
|
2020-10-07 21:59:14 +02:00
|
|
|
from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_torch
|
2020-03-29 00:16:30 +01:00
|
|
|
from ray.rllib.utils.from_config import from_config
|
|
|
|
from ray.rllib.utils.numpy import softmax, SMALL_NUMBER
|
2020-10-07 21:59:14 +02:00
|
|
|
from ray.rllib.utils.typing import TensorType
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from ray.rllib.policy.policy import Policy
|
2020-03-29 00:16:30 +01:00
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
tf1, tf, tfv = try_import_tf()
|
2020-03-29 00:16:30 +01:00
|
|
|
torch, _ = try_import_torch()
|
|
|
|
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@PublicAPI
class ParameterNoise(Exploration):
    """An exploration that changes a Model's parameters.

    Implemented based on:
    [1] https://blog.openai.com/better-exploration-with-parameter-noise/
    [2] https://arxiv.org/pdf/1706.01905.pdf

    At the beginning of an episode, Gaussian noise is added to all weights
    of the model. At the end of the episode, the noise is undone and an action
    diff (pi-delta) is calculated, from which we determine the changes in the
    noise's stddev for the next episode.
    """

    def __init__(
        self,
        action_space,
        *,
        framework: str,
        policy_config: dict,
        model: ModelV2,
        initial_stddev: float = 1.0,
        random_timesteps: int = 10000,
        sub_exploration: Optional[dict] = None,
        **kwargs
    ):
        """Initializes a ParameterNoise Exploration object.

        Args:
            action_space: The action space; forwarded to the parent
                constructor and used for sub-exploration auto-detection.
            framework: The framework identifier. Must not be None.
            policy_config: The policy config dict; forwarded to the parent
                constructor and to the sub-exploration.
            model: The model whose trainable variables are perturbed.
            initial_stddev: The initial stddev to use for the noise.
            random_timesteps: The number of timesteps to act completely
                randomly (see [1]).
            sub_exploration: Optional sub-exploration config.
                None for auto-detection/setup.
            **kwargs: Forwarded to the parent constructor and to the
                sub-exploration constructor.

        Raises:
            NotImplementedError: If `sub_exploration` is None and the action
                space is neither `Discrete` nor `Box`.
        """
        assert framework is not None
        super().__init__(
            action_space,
            policy_config=policy_config,
            model=model,
            framework=framework,
            **kwargs
        )

        self.stddev = get_variable(
            initial_stddev, framework=self.framework, tf_name="stddev"
        )
        self.stddev_val = initial_stddev  # Out-of-graph tf value holder.

        # The weight variables of the Model where noise should be applied to.
        # This excludes any variable whose name contains "LayerNorm"
        # (normalization-layer parameters, which should not be perturbed).
        self.model_variables = [
            v
            for k, v in self.model.trainable_variables(as_dict=True).items()
            if "LayerNorm" not in k
        ]
        # Our noise to be added to the weights. Each item in `self.noise`
        # corresponds to one Model variable and holds the Gaussian noise to
        # be added to that variable (weight).
        self.noise = []
        for var in self.model_variables:
            name_ = var.name.split(":")[0] + "_noisy" if var.name else ""
            self.noise.append(
                get_variable(
                    np.zeros(var.shape, dtype=np.float32),
                    framework=self.framework,
                    tf_name=name_,
                    torch_tensor=True,
                    device=self.device,
                )
            )

        # tf-specific ops to sample, assign and remove noise.
        if self.framework == "tf" and not tf.executing_eagerly():
            self.tf_sample_new_noise_op = self._tf_sample_new_noise_op()
            self.tf_add_stored_noise_op = self._tf_add_stored_noise_op()
            self.tf_remove_noise_op = self._tf_remove_noise_op()
            # Create convenience sample+add op for tf: the add op is only
            # built (and hence run) after the sampling op.
            with tf1.control_dependencies([self.tf_sample_new_noise_op]):
                add_op = self._tf_add_stored_noise_op()
            with tf1.control_dependencies([add_op]):
                self.tf_sample_new_noise_and_add_op = tf.no_op()

        # Whether the Model's weights currently have noise added or not.
        self.weights_are_currently_noisy = False

        # Auto-detection of underlying exploration functionality.
        if sub_exploration is None:
            # For discrete action spaces, use an underlying EpsilonGreedy with
            # a special schedule.
            if isinstance(self.action_space, Discrete):
                sub_exploration = {
                    "type": "EpsilonGreedy",
                    "epsilon_schedule": {
                        "type": "PiecewiseSchedule",
                        # Step function (see [2]).
                        "endpoints": [
                            (0, 1.0),
                            (random_timesteps + 1, 1.0),
                            (random_timesteps + 2, 0.01),
                        ],
                        "outside_value": 0.01,
                    },
                }
            elif isinstance(self.action_space, Box):
                sub_exploration = {
                    "type": "OrnsteinUhlenbeckNoise",
                    "random_timesteps": random_timesteps,
                }
            # TODO(sven): Implement for any action space.
            else:
                raise NotImplementedError

        self.sub_exploration = from_config(
            Exploration,
            sub_exploration,
            framework=self.framework,
            action_space=self.action_space,
            policy_config=self.policy_config,
            model=self.model,
            **kwargs
        )

        # Whether we need to call `self._delayed_on_episode_start` before
        # the forward pass.
        self.episode_started = False

    @override(Exploration)
    def before_compute_actions(
        self,
        *,
        timestep: Optional[int] = None,
        explore: Optional[bool] = None,
        tf_sess: Optional["tf.Session"] = None
    ):
        """Brings the weights into the (noisy/clean) state matching `explore`.

        Args:
            timestep: Not used by this method.
            explore: Whether to explore. If None, falls back to
                `self.policy_config["explore"]`.
            tf_sess: The tf-session used to run the noise ops (only needed
                for framework "tf").
        """
        explore = explore if explore is not None else self.policy_config["explore"]

        # Is this the first forward pass in the new episode? If yes, do the
        # noise re-sampling and add to weights.
        if self.episode_started:
            self._delayed_on_episode_start(explore, tf_sess)

        # Add noise if necessary.
        if explore and not self.weights_are_currently_noisy:
            self._add_stored_noise(tf_sess=tf_sess)
        # Remove noise if necessary.
        elif not explore and self.weights_are_currently_noisy:
            self._remove_noise(tf_sess=tf_sess)

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[TensorType, int],
        explore: Union[TensorType, bool]
    ):
        """Delegates the exploration-action computation to the sub-exploration.

        Returns:
            Whatever `self.sub_exploration.get_exploration_action()` returns.
        """
        # Use our sub-exploration object to handle the final exploration
        # action (depends on the algo-type/action-space/etc..).
        return self.sub_exploration.get_exploration_action(
            action_distribution=action_distribution, timestep=timestep, explore=explore
        )

    @override(Exploration)
    def on_episode_start(
        self,
        policy: "Policy",
        *,
        environment: Optional[BaseEnv] = None,
        episode: Optional[int] = None,
        tf_sess: Optional["tf.Session"] = None
    ):
        """Marks that a new episode started; noise sampling is deferred.

        None of the arguments are used here; only `self.episode_started` is
        set so that the next `before_compute_actions` call does the work.
        """
        # We have to delay the noise-adding step by one forward call.
        # This is due to the fact that the optimizer does its step right
        # after the episode was reset (and hence the noise was already added!).
        # We don't want to update into a noisy net.
        self.episode_started = True

    def _delayed_on_episode_start(self, explore, tf_sess):
        """Samples fresh noise for the new episode and clears the start flag.

        Args:
            explore: If truthy, the fresh noise is also applied to the
                weights (replacing any currently applied noise). Otherwise
                the noise is only sampled and stored.
            tf_sess: The tf-session used to run the noise ops (only needed
                for framework "tf").
        """
        # Sample fresh noise and add to weights.
        if explore:
            self._sample_new_noise_and_add(tf_sess=tf_sess, override=True)
        # Only sample, don't apply anything to the weights.
        else:
            self._sample_new_noise(tf_sess=tf_sess)
        self.episode_started = False

    @override(Exploration)
    def on_episode_end(self, policy, *, environment=None, episode=None, tf_sess=None):
        """Removes the stored noise from the weights if currently applied."""
        # Remove stored noise from weights (only if currently noisy).
        if self.weights_are_currently_noisy:
            self._remove_noise(tf_sess=tf_sess)

    def _compute_action_dist(self, policy, sample_batch, explore):
        """Runs `policy` on `sample_batch` and returns its action-dist values.

        Args:
            policy: The policy to query via
                `compute_actions_from_input_dict`.
            sample_batch: The batch passed as `input_dict`.
            explore: The `explore` flag passed on to the policy.

        Returns:
            For (Torch)Categorical: the softmax of the fetched
            `SampleBatch.ACTION_DIST_INPUTS`. For (Torch)Deterministic: the
            fetched `SampleBatch.ACTION_DIST_INPUTS` unchanged.

        Raises:
            NotImplementedError: For any other `policy.dist_class`.
        """
        _, _, fetches = policy.compute_actions_from_input_dict(
            input_dict=sample_batch, explore=explore
        )
        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            return softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
        # Deterministic (Gaussian actions, e.g. DDPG).
        if policy.dist_class in (Deterministic, TorchDeterministic):
            return fetches[SampleBatch.ACTION_DIST_INPUTS]
        raise NotImplementedError  # TODO(sven): Other action-dist cases.

    @override(Exploration)
    def postprocess_trajectory(
        self,
        policy: "Policy",
        sample_batch: SampleBatch,
        tf_sess: Optional["tf.Session"] = None,
    ):
        """Adapts the noise stddev based on the noisy-vs-clean action distance.

        The policy is evaluated twice on `sample_batch`, once with
        `explore` equal to the current noisy-flag and once with its
        negation. The distance between the two results is compared to a
        threshold (`delta`) derived from the sub-exploration's state; the
        stddev is multiplied by 1.01 if the distance is within the threshold
        and divided by 1.01 otherwise.

        Args:
            policy: The policy used to recompute the action distributions.
            sample_batch: The batch to evaluate the policy on.
            tf_sess: The tf-session passed to the sub-exploration's
                `get_state` and to `set_state`.

        Returns:
            The unmodified `sample_batch`.

        Raises:
            NotImplementedError: If `policy.dist_class` is neither a
                Categorical nor a Deterministic (tf or torch) class.
        """
        noisy_action_dist = noise_free_action_dist = None
        # Adjust the stddev depending on the action (pi)-distance.
        # Also see [1] for details.
        # TODO(sven): Find out whether this can be scrapped by simply using
        # the `sample_batch` to get the noisy/noise-free action dist.
        action_dist = self._compute_action_dist(
            policy, sample_batch, explore=self.weights_are_currently_noisy
        )
        # NOTE: The flag is deliberately (re-)read after the policy call, as
        # the call goes through the exploration hooks.
        if self.weights_are_currently_noisy:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        # Second pass with the opposite `explore` setting.
        action_dist = self._compute_action_dist(
            policy, sample_batch, explore=not self.weights_are_currently_noisy
        )
        if noisy_action_dist is None:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        delta = distance = None
        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            # Calculate KL-divergence (DKL(clean||noisy)) according to [2].
            # TODO(sven): Allow KL-divergence to be calculated by our
            # Distribution classes (don't support off-graph/numpy yet).
            distance = np.nanmean(
                np.sum(
                    noise_free_action_dist
                    * np.log(
                        noise_free_action_dist / (noisy_action_dist + SMALL_NUMBER)
                    ),
                    1,
                )
            )
            current_epsilon = self.sub_exploration.get_state(sess=tf_sess)[
                "cur_epsilon"
            ]
            delta = -np.log(1 - current_epsilon + current_epsilon / self.action_space.n)
        elif policy.dist_class in (Deterministic, TorchDeterministic):
            # Calculate the root-mean-square difference between noisy and
            # non-noisy output (see [2]).
            distance = np.sqrt(
                np.mean(np.square(noise_free_action_dist - noisy_action_dist))
            )
            current_scale = self.sub_exploration.get_state(sess=tf_sess)["cur_scale"]
            delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * current_scale

        # Adjust stddev according to the calculated action-distance.
        if distance <= delta:
            self.stddev_val *= 1.01
        else:
            self.stddev_val /= 1.01

        # Update our state (self.stddev and self.stddev_val).
        self.set_state(self.get_state(), sess=tf_sess)

        return sample_batch

    def _sample_new_noise(self, *, tf_sess=None):
        """Samples new noise and stores it in `self.noise`.

        Args:
            tf_sess: The tf-session used to run the sampling op (only used
                for framework "tf").
        """
        if self.framework == "tf":
            tf_sess.run(self.tf_sample_new_noise_op)
        elif self.framework in ["tfe", "tf2"]:
            self._tf_sample_new_noise_op()
        else:
            # Replace each stored noise tensor (in place in the list) by a
            # freshly sampled one of the same size.
            for i, old_noise in enumerate(self.noise):
                self.noise[i] = torch.normal(
                    mean=torch.zeros(old_noise.size()), std=self.stddev
                ).to(self.device)

    def _tf_sample_new_noise_op(self):
        """Generates a tf-op that assigns fresh Gaussian noise to `self.noise`.

        Also called directly for the "tfe"/"tf2" frameworks.

        Returns:
            The grouped tf op assigning new noise to all noise variables.
        """
        added_noises = []
        for noise in self.noise:
            added_noises.append(
                tf1.assign(
                    noise,
                    tf.random.normal(
                        shape=noise.shape, stddev=self.stddev, dtype=tf.float32
                    ),
                )
            )
        return tf.group(*added_noises)

    def _sample_new_noise_and_add(self, *, tf_sess=None, override=False):
        """Samples new noise and adds it to the model's parameters.

        Sets `self.weights_are_currently_noisy` to True.

        Args:
            tf_sess: The tf-session used to run the ops (only used for
                framework "tf").
            override: If True, undo any currently applied noise first,
                then sample and add the new noise.
        """
        if self.framework == "tf":
            if override and self.weights_are_currently_noisy:
                tf_sess.run(self.tf_remove_noise_op)
            tf_sess.run(self.tf_sample_new_noise_and_add_op)
        else:
            if override and self.weights_are_currently_noisy:
                self._remove_noise()
            self._sample_new_noise()
            self._add_stored_noise()

        self.weights_are_currently_noisy = True

    def _add_stored_noise(self, *, tf_sess=None):
        """Adds the stored `self.noise` to the model's parameters.

        Note: No new sampling of noise here.

        Args:
            tf_sess (Optional[tf.Session]): The tf-session to use to add the
                stored noise to the (currently noise-free) weights.
        """
        # Make sure we only add noise to currently noise-free weights.
        assert self.weights_are_currently_noisy is False

        # Add stored noise to the model's parameters.
        if self.framework == "tf":
            tf_sess.run(self.tf_add_stored_noise_op)
        elif self.framework in ["tf2", "tfe"]:
            self._tf_add_stored_noise_op()
        else:
            for var, noise in zip(self.model_variables, self.noise):
                # Add noise to weights in-place (with grad-tracking switched
                # off for the duration of the update).
                var.requires_grad = False
                var.add_(noise)
                var.requires_grad = True

        self.weights_are_currently_noisy = True

    def _tf_add_stored_noise_op(self):
        """Generates tf-op that assigns the stored noise to weights.

        Also used by tf-eager.

        Returns:
            tf.op: The tf op to apply the already stored noise to the NN.
        """
        add_noise_ops = []
        for var, noise in zip(self.model_variables, self.noise):
            add_noise_ops.append(tf1.assign_add(var, noise))
        ret = tf.group(*add_noise_ops)
        with tf1.control_dependencies([ret]):
            return tf.no_op()

    def _remove_noise(self, *, tf_sess=None):
        """Removes the currently stored noise from the model's parameters.

        Args:
            tf_sess (Optional[tf.Session]): The tf-session to use to remove
                the noise from the (currently noisy) weights.
        """
        # Make sure we only remove noise iff currently noisy.
        assert self.weights_are_currently_noisy is True

        # Removes the stored noise from the model's parameters.
        if self.framework == "tf":
            tf_sess.run(self.tf_remove_noise_op)
        elif self.framework in ["tf2", "tfe"]:
            self._tf_remove_noise_op()
        else:
            for var, noise in zip(self.model_variables, self.noise):
                # Remove noise from weights in-place (with grad-tracking
                # switched off for the duration of the update).
                var.requires_grad = False
                var.add_(-noise)
                var.requires_grad = True

        self.weights_are_currently_noisy = False

    def _tf_remove_noise_op(self):
        """Generates a tf-op for removing noise from the model's weights.

        Also used by tf-eager.

        Returns:
            tf.op: The tf op to remove the currently stored noise from the NN.
        """
        remove_noise_ops = []
        for var, noise in zip(self.model_variables, self.noise):
            remove_noise_ops.append(tf1.assign_add(var, -noise))
        ret = tf.group(*remove_noise_ops)
        with tf1.control_dependencies([ret]):
            return tf.no_op()

    @override(Exploration)
    def get_state(self, sess=None):
        """Returns the current stddev value as `{"cur_stddev": ...}`.

        Args:
            sess: Not used; the value is read from `self.stddev_val`.
        """
        return {"cur_stddev": self.stddev_val}

    @override(Exploration)
    def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
        """Sets `self.stddev_val` and `self.stddev` from `state`.

        Args:
            state: Dict holding the new stddev under key "cur_stddev".
            sess: The tf-session used to load the value into the stddev
                variable (only used for framework "tf").
        """
        self.stddev_val = state["cur_stddev"]
        # Set self.stddev to calculated value.
        if self.framework == "tf":
            self.stddev.load(self.stddev_val, session=sess)
        # Plain Python number (also covers an integer `initial_stddev`,
        # which has no `assign` method): simply rebind.
        elif isinstance(self.stddev, (int, float)):
            self.stddev = self.stddev_val
        else:
            self.stddev.assign(self.stddev_val)
|