[RLlib] Minimal ParamNoise PR. (#7772)

Sven Mika 2020-03-29 00:16:30 +01:00 committed by GitHub
parent 5cebee68d6
commit e4bd5db4d8
17 changed files with 552 additions and 86 deletions

View file

@ -883,7 +883,7 @@ py_test(
py_test(
name = "policy/tests/test_compute_log_likelihoods",
tags = ["policy"],
size = "small",
size = "medium",
srcs = ["policy/tests/test_compute_log_likelihoods.py"]
)
@ -966,7 +966,7 @@ py_test(
py_test(
name = "tests/test_eager_support",
tags = ["tests_dir", "tests_dir_E"],
size = "large",
size = "enormous",
srcs = ["tests/test_eager_support.py"]
)
@ -1361,7 +1361,8 @@ sh_test(
)
py_test(
name = "examples/rock_paper_scissors_multiagent", main = "examples/rock_paper_scissors_multiagent.py",
name = "examples/rock_paper_scissors_multiagent",
main = "examples/rock_paper_scissors_multiagent.py",
tags = ["examples", "examples_R"],
size = "large",
srcs = ["examples/rock_paper_scissors_multiagent.py"],

View file

@ -33,7 +33,7 @@ class TestDQN(unittest.TestCase):
tf_config = config.copy()
tf_config["eager"] = False
trainer = dqn.DQNTrainer(config=tf_config, env="CartPole-v0")
num_iterations = 2
num_iterations = 1
for i in range(num_iterations):
results = trainer.train()
print(results)
@ -44,7 +44,7 @@ class TestDQN(unittest.TestCase):
eager_ctx = eager_mode()
eager_ctx.__enter__()
trainer = dqn.DQNTrainer(config=eager_config, env="CartPole-v0")
num_iterations = 2
num_iterations = 1
for i in range(num_iterations):
results = trainer.train()
print(results)
@ -58,14 +58,21 @@ class TestDQN(unittest.TestCase):
obs = np.array(0)
# Test against all frameworks.
for fw in ["eager", "tf", "torch"]:
for fw in ["tf", "eager", "torch"]:
if fw == "torch":
continue
print("framework={}".format(fw))
config["eager"] = True if fw == "eager" else False
config["use_pytorch"] = True if fw == "torch" else False
eager_mode_ctx = None
if fw == "tf":
assert not tf.executing_eagerly()
else:
eager_mode_ctx = eager_mode()
eager_mode_ctx.__enter__()
config["eager"] = fw == "eager"
config["use_pytorch"] = fw == "torch"
# Default EpsilonGreedy setup.
trainer = dqn.DQNTrainer(config=config, env="FrozenLake-v0")
@ -122,5 +129,6 @@ class TestDQN(unittest.TestCase):
if __name__ == "__main__":
import unittest
unittest.main(verbosity=1)
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -419,7 +419,8 @@ class Trainer(Trainable):
config = config or {}
if tf and config.get("eager"):
tf.enable_eager_execution()
if not tf.executing_eagerly():
tf.enable_eager_execution()
logger.info("Executing eagerly, with eager_tracing={}".format(
"True" if config.get("eager_tracing") else "False"))

View file

@ -151,6 +151,9 @@ class MultiAgentSampleBatchBuilder:
"from a single trajectory.", pre_batch)
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches, episode)
# Call the Policy's Exploration's postprocess method.
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id], getattr(policy, "_sess", None))
if log_once("after_post"):
logger.info(

View file

@ -306,6 +306,14 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
def new_episode():
episode = MultiAgentEpisode(policies, policy_mapping_fn,
get_batch_builder, extra_batch_callback)
# Call each policy's Exploration.on_episode_start method.
for p in policies.values():
p.exploration.on_episode_start(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
# Call custom on_episode_start callback.
if callbacks.get("on_episode_start"):
callbacks["on_episode_start"]({
"env": base_env,
@ -492,6 +500,14 @@ def _process_observations(base_env, policies, batch_builder_pool,
if all_done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
p.exploration.on_episode_end(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
# Call custom on_episode_end callback.
if callbacks.get("on_episode_end"):
callbacks["on_episode_end"]({
"env": base_env,
@ -558,15 +574,15 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
policy = _get_or_raise(policies, policy_id)
if builder and (policy.compute_actions.__code__ is
TFPolicy.compute_actions.__code__):
rnn_in_cols = _to_column_format(rnn_in)
obs_batch = [t.obs for t in eval_data]
state_batches = _to_column_format(rnn_in)
# TODO(ekl): how can we make info batch available to TF code?
# TODO(sven): Return dict from _build_compute_actions.
# Otherwise it's becoming more and more unclear what's where in
# the return tuple.
pending_fetches[policy_id] = policy._build_compute_actions(
builder,
obs_batch=[t.obs for t in eval_data],
state_batches=rnn_in_cols,
obs_batch=obs_batch,
state_batches=state_batches,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
timestep=policy.global_timestep)

View file

@ -261,15 +261,16 @@ def build_eager_tf_policy(name,
@override(Policy)
def postprocess_trajectory(self,
samples,
sample_batch,
other_agent_batches=None,
episode=None):
assert tf.executing_eagerly()
# Call super's postprocess_trajectory first.
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
if postprocess_fn:
return postprocess_fn(self, samples, other_agent_batches,
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
else:
return samples
return sample_batch
@override(Policy)
@convert_eager_inputs
@ -305,6 +306,8 @@ def build_eager_tf_policy(name,
explore = explore if explore is not None else \
self.config["explore"]
timestep = timestep if timestep is not None else \
self.global_timestep
# TODO: remove python side effect to cull sources of bugs.
self._is_training = False
@ -339,19 +342,20 @@ def build_eager_tf_policy(name,
self.action_space,
explore,
self.config,
timestep=timestep
if timestep is not None else self.global_timestep)
timestep=timestep)
# Use Exploration object.
else:
with tf.variable_creator_scope(_disallow_var_creation):
# Call the exploration before_compute_actions hook.
self.exploration.before_compute_actions(timestep=timestep)
model_out, state_out = self.model(input_dict,
state_batches, seq_lens)
action, logp = self.exploration.get_exploration_action(
model_out,
self.dist_class,
self.model,
timestep=timestep
if timestep is not None else self.global_timestep,
timestep=timestep,
explore=explore)
extra_fetches = {}

View file

@ -284,27 +284,6 @@ class Policy(metaclass=ABCMeta):
"""
return self.exploration.get_info()
@DeveloperAPI
def get_exploration_state(self):
"""Returns the current exploration state of this policy.
This state depends on the policy's Exploration object.
Returns:
any: Serializable copy or view of the current exploration state.
"""
raise NotImplementedError
@DeveloperAPI
def set_exploration_state(self, exploration_state):
"""Sets the current exploration state of this Policy.
Arguments:
exploration_state (any): Serializable copy or view of the new
exploration state.
"""
raise NotImplementedError
@DeveloperAPI
def is_recurrent(self):
"""Whether this Policy holds a recurrent Model.

View file

@ -45,14 +45,15 @@ def do_test_log_likelihood(run,
config["use_pytorch"] = fw == "torch"
eager_ctx = None
if fw == "eager":
if fw == "tf":
assert not tf.executing_eagerly()
elif fw == "eager":
eager_ctx = eager_mode()
eager_ctx.__enter__()
assert tf.executing_eagerly()
elif fw == "tf":
assert not tf.executing_eagerly()
trainer = run(config=config, env=env)
policy = trainer.get_policy()
vars = policy.get_weights()
# Sample n actions, then roughly check their logp against their

View file

@ -253,6 +253,7 @@ class TFPolicy(Policy):
timestep=None,
**kwargs):
explore = explore if explore is not None else self.config["explore"]
builder = TFRunBuilder(self._sess, "compute_actions")
fetches = self._build_compute_actions(
builder,
@ -528,11 +529,16 @@ class TFPolicy(Policy):
explore = explore if explore is not None else self.config["explore"]
# Call the exploration before_compute_actions hook.
self.exploration.before_compute_actions(
timestep=self.global_timestep, tf_sess=self.get_session())
state_batches = state_batches or []
if len(self._state_inputs) != len(state_batches):
raise ValueError(
"Must pass in RNN state batches for placeholders {}, got {}".
format(self._state_inputs, state_batches))
builder.add_feed_dict(self.extra_compute_action_feed_dict())
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:

View file

@ -151,10 +151,10 @@ def build_tf_policy(name,
sample_batch,
other_agent_batches=None,
episode=None):
if not postprocess_fn:
return sample_batch
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
if postprocess_fn:
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
return sample_batch
@override(TFPolicy)
def optimizer(self):

View file

@ -71,6 +71,7 @@ class TorchPolicy(Policy):
**kwargs):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
with torch.no_grad():
input_dict = self._lazy_tensor_dict({
@ -81,6 +82,10 @@ class TorchPolicy(Policy):
if prev_reward_batch:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
state_batches = [self._convert_to_tensor(s) for s in state_batches]
# Call the exploration before_compute_actions hook.
self.exploration.before_compute_actions(timestep=timestep)
model_out = self.model(input_dict, state_batches,
self._convert_to_tensor([1]))
logits, state = model_out
@ -88,8 +93,7 @@ class TorchPolicy(Policy):
actions, logp = \
self.exploration.get_exploration_action(
logits, self.dist_class, self.model,
timestep if timestep is not None else
self.global_timestep, explore)
timestep, explore)
input_dict[SampleBatch.ACTIONS] = actions
extra_action_out = self.extra_action_out(input_dict, state_batches,
@ -100,8 +104,8 @@ class TorchPolicy(Policy):
ACTION_PROB: np.exp(logp),
ACTION_LOGP: logp
})
return convert_to_non_torch_type(
(actions, state, extra_action_out))
return convert_to_non_torch_type((actions, state,
extra_action_out))
@override(Policy)
def compute_log_likelihoods(self,

View file

@ -86,10 +86,8 @@ def build_torch_policy(name,
self.config["model"],
framework="torch")
TorchPolicy.__init__(
self, obs_space, action_space, config, self.model,
loss_fn, self.dist_class
)
TorchPolicy.__init__(self, obs_space, action_space, config,
self.model, loss_fn, self.dist_class)
if after_init:
after_init(self, obs_space, action_space, config)
@ -117,17 +115,18 @@ def build_torch_policy(name,
return TorchPolicy.extra_grad_process(self)
@override(TorchPolicy)
def extra_action_out(self, input_dict, state_batches, model,
def extra_action_out(self,
input_dict,
state_batches,
model,
action_dist=None):
with torch.no_grad():
if extra_action_out_fn:
stats_dict = extra_action_out_fn(
self, input_dict, state_batches, model, action_dist
)
self, input_dict, state_batches, model, action_dist)
else:
stats_dict = TorchPolicy.extra_action_out(
self, input_dict, state_batches, model, action_dist
)
self, input_dict, state_batches, model, action_dist)
return convert_to_non_torch_type(stats_dict)
@override(TorchPolicy)

View file

@ -87,8 +87,9 @@ class TestEagerSupport(unittest.TestCase):
},
})
def test_sac(self):
check_support("SAC", {"num_workers": 0})
# TODO(sven): Add this once SAC supports eager.
# def test_sac(self):
# check_support("SAC", {"num_workers": 0, "learning_starts": 0})
if __name__ == "__main__":

View file

@ -1,12 +1,15 @@
from gym.spaces import Space
from typing import Union
from ray.rllib.utils.framework import check_framework, try_import_tf, \
TensorType
from ray.rllib.models.modelv2 import ModelV2
from typing import Union
from ray.rllib.utils.annotations import DeveloperAPI
tf = try_import_tf()
@DeveloperAPI
class Exploration:
"""Implements an exploration strategy for Policies.
@ -32,6 +35,24 @@ class Exploration:
self.worker_index = worker_index
self.framework = check_framework(framework)
@DeveloperAPI
def before_compute_actions(self,
*,
timestep=None,
explore=None,
tf_sess=None,
**kwargs):
"""Hook for preparations before policy.compute_actions() is called.
Args:
timestep (Optional[TensorType]): An optional timestep tensor.
explore (Optional[TensorType]): An optional explore boolean flag.
tf_sess (Optional[tf.Session]): The tf-session object to use.
**kwargs: Forward compatibility kwargs.
"""
pass
@DeveloperAPI
def get_exploration_action(self,
distribution_inputs: TensorType,
action_dist_class: type,
@ -64,25 +85,55 @@ class Exploration:
"""
pass
def get_loss_exploration_term(self,
model_output: TensorType,
model: ModelV2,
action_dist: type,
action_sample: TensorType = None):
"""Returns an extra loss term to be added to a loss.
@DeveloperAPI
def on_episode_start(self,
policy,
*,
environment=None,
episode=None,
tf_sess=None):
"""Handles necessary exploration logic at the beginning of an episode.
Args:
model_output (TensorType): The Model's output Tensor(s).
model (ModelV2): The Model object.
action_dist: The ActionDistribution object resulting from
`model_output`. TODO: Or the class?
action_sample (TensorType): An optional action sample.
Returns:
TensorType: The extra loss term to add to the loss.
policy (Policy): The Policy object that holds this Exploration.
environment (BaseEnv): The environment object we are acting in.
episode (int): The number of the episode that is starting.
tf_sess (Optional[tf.Session]): In case of tf, the session object.
"""
pass # TODO(sven): implement for some example Exploration class.
pass
@DeveloperAPI
def on_episode_end(self,
policy,
*,
environment=None,
episode=None,
tf_sess=None):
"""Handles necessary exploration logic at the end of an episode.
Args:
policy (Policy): The Policy object that holds this Exploration.
environment (BaseEnv): The environment object we are acting in.
episode (int): The number of the episode that just ended.
tf_sess (Optional[tf.Session]): In case of tf, the session object.
"""
pass
@DeveloperAPI
def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
"""Handles post-processing of done episode trajectories.
Changes the given batch in place. This callback is invoked by the
sampler after policy.postprocess_trajectory() is called.
Args:
policy (Policy): The owning policy object.
sample_batch (SampleBatch): The SampleBatch object to post-process.
tf_sess (Optional[tf.Session]): An optional tf.Session object.
"""
return sample_batch
@DeveloperAPI
def get_info(self):
"""Returns a description of the current exploration state.

View file

@ -0,0 +1,382 @@
from gym.spaces import Discrete
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.framework import get_variable
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import softmax, SMALL_NUMBER
tf = try_import_tf()
torch, _ = try_import_torch()
class ParameterNoise(Exploration):
"""An exploration that changes a Model's parameters.
Implemented based on:
[1] https://blog.openai.com/better-exploration-with-parameter-noise/
[2] https://arxiv.org/pdf/1706.01905.pdf
At the beginning of an episode, Gaussian noise is added to all weights
of the model. At the end of the episode, the noise is undone and an action
diff (pi-delta) is calculated, from which we determine the changes in the
noise's stddev for the next episode.
"""
def __init__(self,
action_space,
*,
framework: str,
policy_config: dict,
model: ModelV2,
initial_stddev=1.0,
random_timesteps=10000,
sub_exploration=None,
**kwargs):
"""Initializes a ParameterNoise Exploration object.
Args:
initial_stddev (float): The initial stddev to use for the noise.
random_timesteps (int): The number of timesteps to act completely
randomly (see [1]).
sub_exploration (Optional[dict]): Optional sub-exploration config.
None for auto-detection/setup.
"""
assert framework is not None
super().__init__(action_space, framework=framework, **kwargs)
# TODO(sven): Move these to base-Exploration class.
self.policy_config = policy_config
self.model = model
self.stddev = get_variable(
initial_stddev, framework=self.framework, tf_name="stddev")
self.stddev_val = initial_stddev # Out-of-graph tf value holder.
# The weight variables of the Model where noise should be applied to.
# This excludes any variable whose name contains "LayerNorm" (those
# are normalization layers, which should not be perturbed).
self.model_variables = [
v for v in self.model.variables() if "LayerNorm" not in v.name
]
# Our noise to be added to the weights. Each item in `self.noise`
# corresponds to one Model variable and holds the Gaussian noise to
# be added to that variable (weight).
self.noise = []
for var in self.model_variables:
self.noise.append(
get_variable(
np.zeros(var.shape, dtype=np.float32),
framework=self.framework,
tf_name=var.name.split(":")[0] + "_noisy"))
# tf-specific ops to sample, assign and remove noise.
if self.framework == "tf" and not tf.executing_eagerly():
self.tf_sample_new_noise_op = \
self._tf_sample_new_noise_op()
self.tf_add_stored_noise_op = \
self._tf_add_stored_noise_op()
self.tf_remove_noise_op = \
self._tf_remove_noise_op()
# Create convenience sample+add op for tf.
with tf.control_dependencies([self.tf_sample_new_noise_op]):
add_op = self._tf_add_stored_noise_op()
with tf.control_dependencies([add_op]):
self.tf_sample_new_noise_and_add_op = tf.no_op()
# Whether the Model's weights currently have noise added or not.
self.weights_are_currently_noisy = False
# Auto-detection of underlying exploration functionality.
if sub_exploration is None:
# For discrete action spaces, use an underlying EpsilonGreedy with
# a special schedule.
if isinstance(self.action_space, Discrete):
sub_exploration = {
"type": "EpsilonGreedy",
"epsilon_schedule": {
"type": "PiecewiseSchedule",
# Step function (see [2]).
"endpoints": [(0, 1.0), (random_timesteps + 1, 1.0),
(random_timesteps + 2, 0.01)],
"outside_value": 0.01
}
}
# TODO(sven): Implement for any action space.
else:
raise NotImplementedError
self.sub_exploration = from_config(
Exploration,
sub_exploration,
framework=self.framework,
action_space=self.action_space,
**kwargs)
# Store the default setting for `explore`.
self.default_explore = policy_config["explore"]
# Whether we need to call `self._delayed_on_episode_start` before
# the forward pass.
self.episode_started = False
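
A hedged usage sketch (not part of this diff), assuming ParameterNoise is selected through the `exploration_config` dict that `from_config` resolves; the surrounding keys are illustrative only.

# Hypothetical trainer/policy config snippet; only "explore" and the
# constructor arguments shown above are taken from this diff.
config = {
    "explore": True,
    "exploration_config": {
        "type": "ParameterNoise",   # resolved via from_config()
        "initial_stddev": 1.0,      # starting noise stddev
        "random_timesteps": 10000,  # act randomly for the first N timesteps (see [1])
        # "sub_exploration": None -> auto-detect (EpsilonGreedy for Discrete spaces)
    },
}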
@override(Exploration)
def before_compute_actions(self,
*,
timestep=None,
explore=None,
tf_sess=None):
# Is this the first forward pass in the new episode? If yes, do the
# noise re-sampling and add to weights.
if self.episode_started:
self._delayed_on_episode_start(tf_sess)
explore = explore if explore is not None else \
self.policy_config["explore"]
# Add noise if necessary.
if explore and not self.weights_are_currently_noisy:
self._add_stored_noise(tf_sess=tf_sess)
# Remove noise if necessary.
elif not explore and self.weights_are_currently_noisy:
self._remove_noise(tf_sess=tf_sess)
@override(Exploration)
def get_exploration_action(self,
*,
distribution_inputs,
action_dist_class,
timestep,
explore=True):
# Use our sub-exploration object to handle the final exploration
# action (depends on the algo-type/action-space/etc..).
return self.sub_exploration.get_exploration_action(
distribution_inputs=distribution_inputs,
action_dist_class=action_dist_class,
timestep=timestep,
explore=explore)
@override(Exploration)
def on_episode_start(self,
policy,
*,
environment=None,
episode=None,
tf_sess=None):
# We have to delay the noise-adding step by one forward call.
# This is due to the fact that the optimizer does its step right
# after the episode was reset (and hence the noise was already added!).
# We don't want to update into a noisy net.
self.episode_started = True
def _delayed_on_episode_start(self, tf_sess):
# Sample fresh noise and add to weights.
if self.default_explore:
self._sample_new_noise_and_add(tf_sess=tf_sess, override=True)
# Only sample, don't apply anything to the weights.
else:
self._sample_new_noise(tf_sess=tf_sess)
self.episode_started = False
@override(Exploration)
def on_episode_end(self,
policy,
*,
environment=None,
episode=None,
tf_sess=None):
# Remove stored noise from weights (only if currently noisy).
if self.weights_are_currently_noisy:
self._remove_noise(tf_sess=tf_sess)
@override(Exploration)
def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
noisy_action_dist = noise_free_action_dist = None
# Adjust the stddev depending on the action (pi)-distance.
# Also see [1] for details.
distribution = policy.compute_action_distribution(
obs_batch=sample_batch[SampleBatch.CUR_OBS],
# TODO(sven): What about state-ins and seq-lens?
prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
explore=self.weights_are_currently_noisy)
# Categorical case (e.g. DQN).
if isinstance(distribution, Categorical):
action_dist = softmax(distribution.inputs)
else: # TODO(sven): Other action-dist cases.
raise NotImplementedError
if self.weights_are_currently_noisy:
noisy_action_dist = action_dist
else:
noise_free_action_dist = action_dist
distribution = policy.compute_action_distribution(
obs_batch=sample_batch[SampleBatch.CUR_OBS],
# TODO(sven): What about state-ins and seq-lens?
prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
explore=not self.weights_are_currently_noisy)
# Categorical case (e.g. DQN).
if isinstance(distribution, Categorical):
action_dist = softmax(distribution.inputs)
if not self.weights_are_currently_noisy:
noisy_action_dist = action_dist
else:
noise_free_action_dist = action_dist
# Categorical case (e.g. DQN).
if isinstance(distribution, Categorical):
# Calculate KL-divergence (DKL(clean||noisy)) according to [2].
# TODO(sven): Allow KL-divergence to be calculated by our
# Distribution classes (don't support off-graph/numpy yet).
kl_divergence = np.nanmean(
np.sum(
noise_free_action_dist *
np.log(noise_free_action_dist /
(noisy_action_dist + SMALL_NUMBER)), 1))
current_epsilon = self.sub_exploration.get_info()["cur_epsilon"]
if tf_sess is not None:
current_epsilon = tf_sess.run(current_epsilon)
delta = -np.log(1 - current_epsilon +
current_epsilon / self.action_space.n)
if kl_divergence <= delta:
self.stddev_val *= 1.01
else:
self.stddev_val /= 1.01
# Set self.stddev to calculated value.
if self.framework == "tf":
self.stddev.load(self.stddev_val, session=tf_sess)
else:
self.stddev = self.stddev_val
return sample_batch
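
An illustrative-only numpy walk-through (not part of this diff) of the stddev adaptation above, using made-up action distributions over a 4-action Discrete space and epsilon = 0.1:

import numpy as np

SMALL_NUMBER = 1e-6
clean = np.array([[0.7, 0.1, 0.1, 0.1]])  # noise-free action distribution
noisy = np.array([[0.4, 0.3, 0.2, 0.1]])  # distribution under noisy weights
kl = np.nanmean(np.sum(clean * np.log(clean / (noisy + SMALL_NUMBER)), 1))

eps, n = 0.1, 4
delta = -np.log(1 - eps + eps / n)  # distance threshold from [2]
stddev_val = 1.0
# Grow the noise if the two policies are still too close, shrink it otherwise.
stddev_val = stddev_val * 1.01 if kl <= delta else stddev_val / 1.01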
def _sample_new_noise(self, *, tf_sess=None):
"""Samples new noise and stores it in `self.noise`."""
if self.framework == "tf":
if tf.executing_eagerly():
self._tf_sample_new_noise_op()
else:
tf_sess.run(self.tf_sample_new_noise_op)
else:
for i in range(len(self.noise)):
self.noise[i] = torch.normal(
0.0, self.stddev, size=self.noise[i].shape)
def _tf_sample_new_noise_op(self):
added_noises = []
for noise in self.noise:
added_noises.append(
tf.assign(
noise,
tf.random_normal(
shape=noise.shape,
stddev=self.stddev,
dtype=tf.float32)))
return tf.group(*added_noises)
def _sample_new_noise_and_add(self, *, tf_sess=None, override=False):
if self.framework == "tf" and not tf.executing_eagerly():
if override and self.weights_are_currently_noisy:
tf_sess.run(self.tf_remove_noise_op)
tf_sess.run(self.tf_sample_new_noise_and_add_op)
else:
if override and self.weights_are_currently_noisy:
self._remove_noise()
self._sample_new_noise()
self._add_stored_noise()
self.weights_are_currently_noisy = True
def _add_stored_noise(self, *, tf_sess=None):
"""Adds the stored `self.noise` to the model's parameters.
Note: No new sampling of noise here.
Args:
tf_sess (Optional[tf.Session]): The tf-session to use to add the
stored noise to the (currently noise-free) weights.
"""
# Make sure we only add noise to currently noise-free weights.
assert self.weights_are_currently_noisy is False
if self.framework == "tf":
if tf.executing_eagerly():
self._tf_add_stored_noise_op()
else:
tf_sess.run(self.tf_add_stored_noise_op)
# Add stored noise to the model's parameters.
else:
for i in range(len(self.noise)):
# Add noise to weights in-place.
self.model_variables[i].add_(self.noise[i])
self.weights_are_currently_noisy = True
def _tf_add_stored_noise_op(self):
"""Generates tf-op that assigns the stored noise to weights.
Also used by tf-eager.
Returns:
tf.op: The tf op to apply the already stored noise to the NN.
"""
add_noise_ops = list()
for var, noise in zip(self.model_variables, self.noise):
add_noise_ops.append(tf.assign_add(var, noise))
ret = tf.group(*tuple(add_noise_ops))
with tf.control_dependencies([ret]):
return tf.no_op()
def _remove_noise(self, *, tf_sess=None):
"""
Removes the current action noise from the model parameters.
Args:
tf_sess (Optional[tf.Session]): The tf-session to use to remove
the noise from the (currently noisy) weights.
"""
# Make sure we only remove noise iff currently noisy.
assert self.weights_are_currently_noisy is True
if self.framework == "tf":
if tf.executing_eagerly():
self._tf_remove_noise_op()
else:
tf_sess.run(self.tf_remove_noise_op)
else:
# Removes the stored noise from the model's parameters.
for var, noise in zip(self.model_variables, self.noise):
# Remove noise from weights in-place.
var.add_(-noise)
self.weights_are_currently_noisy = False
def _tf_remove_noise_op(self):
"""Generates a tf-op for removing noise from the model's weights.
Also used by tf-eager.
Returns:
tf.op: The tf op to remove the currently stored noise from the NN.
"""
remove_noise_ops = list()
for var, noise in zip(self.model_variables, self.noise):
remove_noise_ops.append(tf.assign_add(var, -noise))
ret = tf.group(*tuple(remove_noise_ops))
with tf.control_dependencies([ret]):
return tf.no_op()
@override(Exploration)
def get_info(self):
return {"cur_stddev": self.stddev}

View file

@ -12,7 +12,9 @@ import ray.rllib.agents.impala as impala
import ray.rllib.agents.pg as pg
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.sac as sac
from ray.rllib.utils import check
from ray.rllib.utils import check, try_import_tf
tf = try_import_tf()
def do_test_explorations(run,
@ -53,6 +55,9 @@ def do_test_explorations(run,
eager_mode_ctx = eager_mode()
if fw == "eager":
eager_mode_ctx.__enter__()
assert tf.executing_eagerly()
elif fw == "tf":
assert not tf.executing_eagerly()
trainer = run(config=config, env=env)

View file

@ -134,7 +134,12 @@ def get_variable(value, framework="tf", tf_name="unnamed-variable"):
"""
if framework == "tf":
import tensorflow as tf
return tf.compat.v1.get_variable(tf_name, initializer=value)
dtype = getattr(
value, "dtype", tf.float32
if isinstance(value, float) else tf.int32
if isinstance(value, int) else None)
return tf.compat.v1.get_variable(
tf_name, initializer=value, dtype=dtype)
# torch or None: Return python primitive.
return value
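
A brief usage sketch of the helper above (illustrative values; the tf branch assumes TensorFlow is installed and graph mode is active):

# For tf, a python float becomes a float32 tf.Variable named "stddev";
# for torch (or None), the python primitive is returned unchanged.
stddev_tf = get_variable(1.0, framework="tf", tf_name="stddev")
stddev_torch = get_variable(1.0, framework="torch")  # -> 1.0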