[RLlib] DDPG and SAC eager support (preparation for tf2.x) (#9204)

Sven Mika 2020-07-08 16:12:20 +02:00 committed by GitHub
parent 42f8f16c04
commit 4da0e542d5
30 changed files with 383 additions and 370 deletions

View file

@ -201,7 +201,7 @@ build_sphinx_docs() {
if [ "${OSTYPE}" = msys ]; then
echo "WARNING: Documentation not built on Windows due to currently-unresolved issues"
else
sphinx-build -q -E -T -b html source _build/html
sphinx-build -q -E -W -T -b html source _build/html
fi
)
}

View file

@ -210,12 +210,13 @@ PettingZoo Multi-Agent Environments
`PettingZoo <https://github.com/PettingZoo-Team/PettingZoo>`__ is a repository of over 50 diverse multi-agent environments. Its API is not directly compatible with RLlib, but an environment can be wrapped into an RLlib MultiAgentEnv as in this example
.. code-block:: python
from ray.tune.registry import register_env
# import the pettingzoo environment
from pettingzoo.gamma import prison_v0
# import rllib pettingzoo interface
from ray.rllib.env import PettingZooEnv
# define how to make the environment. This way takes an optinoal environment config, num_floors
# define how to make the environment. This way takes an optional environment config, num_floors
env_creator = lambda config: prison_v0.env(num_floors=config.get("num_floors", 4))
# register that way to make the environment under an rllib name
register_env('prison', lambda config: PettingZooEnv(env_creator(config)))
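As a follow-up, a minimal training sketch (illustrative only, not part of this change); the choice of PPOTrainer and the one-iteration stop criterion are assumptions:
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
# Train on the PettingZoo env registered above under the name "prison".
# "env_config" is forwarded to env_creator through the registered lambda.
tune.run(
    PPOTrainer,
    stop={"training_iteration": 1},
    config={"env": "prison", "env_config": {"num_floors": 4}})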

View file

@ -38,7 +38,7 @@ Then, you can try out training in the following equivalent ways:
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
tune.run(PPOTrainer, config={"env": "CartPole-v0"}) # "log_level": "INFO" for verbose,
# "framework": "tfe" for tf-eager execution,
# "framework": "tfe" for tf-eager,
# "framework": "torch" for PyTorch
Next, we'll cover three key concepts in RLlib: Policies, Samples, and Trainers.

View file

@ -205,16 +205,16 @@ def update_backend_config(backend_tag, config_options):
backend_tag(str): A registered backend.
config_options(dict): Backend config options to update.
Supported options:
- "num_replicas": number of worker processes to start up that \
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will \
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas \
will wait for a full batch of requests before \
processing a partial batch.
- "max_concurrent_queries": the maximum number of queries \
that will be sent to a replica of this backend \
without receiving a response.
- "num_replicas": number of worker processes to start up that
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas
will wait for a full batch of requests before
processing a partial batch.
- "max_concurrent_queries": the maximum number of queries
that will be sent to a replica of this backend
without receiving a response.
"""
if not isinstance(config_options, dict):
raise ValueError("config_options must be a dictionary.")
@ -252,16 +252,16 @@ def create_backend(backend_tag,
@ray.remote decorator for the backend actor.
config (optional): configuration options for this backend.
Supported options:
- "num_replicas": number of worker processes to start up that \
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will \
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas \
will wait for a full batch of requests before \
processing a partial batch.
- "max_concurrent_queries": the maximum number of queries \
that will be sent to a replica of this backend \
without receiving a response.
- "num_replicas": number of worker processes to start up that will
handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas
will wait for a full batch of requests before processing a
partial batch.
- "max_concurrent_queries": the maximum number of queries that will
be sent to a replica of this backend without receiving a
response.
"""
if config is None:
config = {}
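For illustration only, a small sketch of how the documented options might be passed end to end; the echo function, the backend name, and the serve.init() call are assumptions about this version's serve API rather than part of the diff:
from ray import serve

serve.init()  # Assumption: serve is started/connected this way in this Ray version.

def echo(flask_request):
    return "hello"

# Pass the documented backend options via the `config` dict at creation time.
serve.create_backend(
    "echo_backend",
    echo,
    config={
        "num_replicas": 2,
        "max_batch_size": 4,
        "batch_wait_timeout": 0.1,
        "max_concurrent_queries": 8,
    })

# The same options can later be changed in place.
serve.update_backend_config("echo_backend", {"num_replicas": 3})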

View file

@ -1148,10 +1148,21 @@ py_test(
)
py_test(
name = "tests/test_eager_support",
name = "tests/test_eager_support_pg",
main = "tests/test_eager_support.py",
tags = ["tests_dir", "tests_dir_E"],
size = "enormous",
srcs = ["tests/test_eager_support.py"]
size = "large",
srcs = ["tests/test_eager_support.py"],
args = ["TestEagerSupportPG"]
)
py_test(
name = "tests/test_eager_support_off_policy",
main = "tests/test_eager_support.py",
tags = ["tests_dir", "tests_dir_E"],
size = "large",
srcs = ["tests/test_eager_support.py"],
args = ["TestEagerSupportOffPolicy"]
)
py_test(
@ -1269,6 +1280,13 @@ py_test(
srcs = ["tests/test_nested_observation_spaces.py"]
)
py_test(
name = "tests/test_pettingzoo_env",
tags = ["tests_dir", "tests_dir_P"],
size = "medium",
srcs = ["tests/test_pettingzoo_env.py"]
)
py_test(
name = "tests/test_reproducibility",
tags = ["tests_dir", "tests_dir_R"],
@ -1311,17 +1329,30 @@ py_test(
)
py_test(
name = "tests/test_pettingzoo_env",
name = "tests/test_supported_spaces_pg",
main = "tests/test_supported_spaces.py",
tags = ["tests_dir", "tests_dir_S"],
size = "medium",
srcs = ["tests/test_pettingzoo_env.py"]
size = "enormous",
srcs = ["tests/test_supported_spaces.py"],
args = ["TestSupportedSpacesPG"]
)
py_test(
name = "tests/test_supported_spaces",
name = "tests/test_supported_spaces_off_policy",
main = "tests/test_supported_spaces.py",
tags = ["tests_dir", "tests_dir_S"],
size = "enormous",
srcs = ["tests/test_supported_spaces.py"]
srcs = ["tests/test_supported_spaces.py"],
args = ["TestSupportedSpacesOffPolicy"]
)
py_test(
name = "tests/test_supported_spaces_evolution_algos",
main = "tests/test_supported_spaces.py",
tags = ["tests_dir", "tests_dir_S"],
size = "large",
srcs = ["tests/test_supported_spaces.py"],
args = ["TestSupportedSpacesEvolutionAlgos"]
)
# --------------------------------------------------------------------

View file

@ -55,7 +55,7 @@ def postprocess_advantages(policy,
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],

View file

@ -24,31 +24,34 @@ class TestA2C(unittest.TestCase):
num_iterations = 1
# Test against all frameworks.
for fw in framework_iterator(config, ("tf", "torch")):
config["sample_async"] = fw == "tf"
for fw in framework_iterator(config):
config["sample_async"] = fw in ["tf", "tfe"]
for env in ["PongDeterministic-v0"]:
trainer = a3c.A2CTrainer(config=config, env=env)
for i in range(num_iterations):
results = trainer.train()
print(results)
check_compute_single_action(trainer)
trainer.stop()
def test_a2c_exec_impl(ray_start_regular):
config = {"min_iter_time_s": 0}
for _ in framework_iterator(config, ("tf", "torch")):
for _ in framework_iterator(config):
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
assert isinstance(trainer.train(), dict)
check_compute_single_action(trainer)
trainer.stop()
def test_a2c_exec_impl_microbatch(ray_start_regular):
config = {
"min_iter_time_s": 0,
"microbatch_size": 10,
}
for _ in framework_iterator(config, ("tf", "torch")):
for _ in framework_iterator(config):
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
assert isinstance(trainer.train(), dict)
check_compute_single_action(trainer)
trainer.stop()
if __name__ == "__main__":

View file

@ -24,7 +24,7 @@ class TestA3C(unittest.TestCase):
num_iterations = 1
# Test against all frameworks.
for fw in framework_iterator(config, ("tf", "torch")):
for fw in framework_iterator(config):
config["sample_async"] = fw == "tf"
for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
trainer = a3c.A3CTrainer(config=config, env=env)
@ -32,6 +32,7 @@ class TestA3C(unittest.TestCase):
results = trainer.train()
print(results)
check_compute_single_action(trainer)
trainer.stop()
if __name__ == "__main__":

View file

@ -27,7 +27,14 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
},
)
def validate_config(config):
if config.get("framework") == "tfe":
raise ValueError("APEX_DDPG does not support tf-eager yet!")
ApexDDPGTrainer = DDPGTrainer.with_updates(
name="APEX_DDPG",
default_config=APEX_DDPG_DEFAULT_CONFIG,
validate_config=validate_config,
execution_plan=apex_execution_plan)

View file

@ -5,8 +5,6 @@ from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.utils.deprecation import deprecation_warning, \
DEPRECATED_VALUE
from ray.rllib.utils.exploration.per_worker_ornstein_uhlenbeck_noise import \
PerWorkerOrnsteinUhlenbeckNoise
logger = logging.getLogger(__name__)
@ -129,7 +127,7 @@ DEFAULT_CONFIG = with_common_config({
# Weights for L2 regularization
"l2_reg": 1e-6,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": None,
"grad_clip": None,
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# Update the replay buffer with this many samples at once. Note that this
@ -151,7 +149,7 @@ DEFAULT_CONFIG = with_common_config({
"min_iter_time_s": 1,
# Deprecated keys.
"parameter_noise": DEPRECATED_VALUE,
"grad_norm_clipping": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable
@ -164,41 +162,12 @@ def validate_config(config):
"was specified.")
config["use_state_preprocessor"] = True
# TODO(sven): Remove at some point.
# Backward compatibility of noise-based exploration config.
schedule_max_timesteps = None
if config.get("schedule_max_timesteps", DEPRECATED_VALUE) != \
DEPRECATED_VALUE:
deprecation_warning("schedule_max_timesteps",
"exploration_config.scale_timesteps")
schedule_max_timesteps = config["schedule_max_timesteps"]
if config.get("exploration_final_scale", DEPRECATED_VALUE) != \
DEPRECATED_VALUE:
deprecation_warning("exploration_final_scale",
"exploration_config.final_scale")
if isinstance(config["exploration_config"], dict):
config["exploration_config"]["final_scale"] = \
config.pop("exploration_final_scale")
if config.get("exploration_fraction", DEPRECATED_VALUE) != \
DEPRECATED_VALUE:
assert schedule_max_timesteps is not None
deprecation_warning("exploration_fraction",
"exploration_config.scale_timesteps")
if isinstance(config["exploration_config"], dict):
config["exploration_config"]["scale_timesteps"] = config.pop(
"exploration_fraction") * schedule_max_timesteps
if config.get("per_worker_exploration", DEPRECATED_VALUE) != \
DEPRECATED_VALUE:
deprecation_warning(
"per_worker_exploration",
"exploration_config.type=PerWorkerOrnsteinUhlenbeckNoise")
if isinstance(config["exploration_config"], dict):
config["exploration_config"]["type"] = \
PerWorkerOrnsteinUhlenbeckNoise
if config.get("grad_norm_clipping", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning("grad_norm_clipping", "grad_clip")
config["grad_clip"] = config.pop("grad_norm_clipping")
if config.get("parameter_noise", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning("parameter_noise", "exploration_config={"
"type=ParameterNoise}")
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")
if config["exploration_config"]["type"] == "ParameterNoise":
if config["batch_mode"] != "complete_episodes":

View file

@ -18,22 +18,13 @@ from ray.rllib.utils.annotations import override
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \
make_tf_callable
from ray.rllib.utils.framework import get_variable, try_import_tf
from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
ACTION_SCOPE = "action"
POLICY_SCOPE = "policy"
POLICY_TARGET_SCOPE = "target_policy"
Q_SCOPE = "critic"
Q_TARGET_SCOPE = "target_critic"
TWIN_Q_SCOPE = "twin_critic"
TWIN_Q_TARGET_SCOPE = "twin_target_critic"
def build_ddpg_models(policy, observation_space, action_space, config):
if policy.config["use_state_preprocessor"]:
@ -126,59 +117,45 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
# Policy network evaluation.
with tf1.variable_scope(POLICY_SCOPE, reuse=True):
# prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
policy_t = model.get_policy_output(model_out_t)
# policy_batchnorm_update_ops = list(
# set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
with tf1.variable_scope(POLICY_TARGET_SCOPE):
policy_tp1 = \
policy.target_model.get_policy_output(target_model_out_tp1)
policy_t = model.get_policy_output(model_out_t)
policy_tp1 = \
policy.target_model.get_policy_output(target_model_out_tp1)
# Action outputs.
with tf1.variable_scope(ACTION_SCOPE, reuse=True):
if policy.config["smooth_target_policy"]:
target_noise_clip = policy.config["target_noise_clip"]
clipped_normal_sample = tf.clip_by_value(
tf.random.normal(
tf.shape(policy_tp1),
stddev=policy.config["target_noise"]), -target_noise_clip,
target_noise_clip)
policy_tp1_smoothed = tf.clip_by_value(
policy_tp1 + clipped_normal_sample,
policy.action_space.low * tf.ones_like(policy_tp1),
policy.action_space.high * tf.ones_like(policy_tp1))
else:
# No smoothing, just use deterministic actions.
policy_tp1_smoothed = policy_tp1
if policy.config["smooth_target_policy"]:
target_noise_clip = policy.config["target_noise_clip"]
clipped_normal_sample = tf.clip_by_value(
tf.random.normal(
tf.shape(policy_tp1),
stddev=policy.config["target_noise"]), -target_noise_clip,
target_noise_clip)
policy_tp1_smoothed = tf.clip_by_value(
policy_tp1 + clipped_normal_sample,
policy.action_space.low * tf.ones_like(policy_tp1),
policy.action_space.high * tf.ones_like(policy_tp1))
else:
# No smoothing, just use deterministic actions.
policy_tp1_smoothed = policy_tp1
# Q-net(s) evaluation.
# prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf1.variable_scope(Q_SCOPE):
# Q-values for given actions & observations in given current
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
# Q-values for given actions & observations in given current
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
with tf1.variable_scope(Q_SCOPE, reuse=True):
# Q-values for current policy (no noise) in given current state
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
# Q-values for current policy (no noise) in given current state
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
if twin_q:
with tf1.variable_scope(TWIN_Q_SCOPE):
twin_q_t = model.get_twin_q_values(
model_out_t, train_batch[SampleBatch.ACTIONS])
# q_batchnorm_update_ops = list(
# set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
twin_q_t = model.get_twin_q_values(
model_out_t, train_batch[SampleBatch.ACTIONS])
# Target q-net(s) evaluation.
with tf1.variable_scope(Q_TARGET_SCOPE):
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
policy_tp1_smoothed)
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
policy_tp1_smoothed)
if twin_q:
with tf1.variable_scope(TWIN_Q_TARGET_SCOPE):
twin_q_tp1 = policy.target_model.get_twin_q_values(
target_model_out_tp1, policy_tp1_smoothed)
twin_q_tp1 = policy.target_model.get_twin_q_values(
target_model_out_tp1, policy_tp1_smoothed)
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
if twin_q:
@ -220,10 +197,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
if l2_reg is not None:
for var in policy.model.policy_variables():
if "bias" not in var.name:
actor_loss += (l2_reg * tf1.nn.l2_loss(var))
actor_loss += (l2_reg * tf.nn.l2_loss(var))
for var in policy.model.q_variables():
if "bias" not in var.name:
critic_loss += (l2_reg * tf1.nn.l2_loss(var))
critic_loss += (l2_reg * tf.nn.l2_loss(var))
# Model self-supervised losses.
if policy.config["use_state_preprocessor"]:
@ -259,28 +236,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
def make_ddpg_optimizers(policy, config):
# Create separate optimizers for actor & critic losses.
policy._actor_optimizer = tf1.train.AdamOptimizer(
learning_rate=config["actor_lr"])
policy._critic_optimizer = tf1.train.AdamOptimizer(
learning_rate=config["critic_lr"])
if tfv == 2 and config["framework"] == "tfe":
policy._actor_optimizer = tf.keras.optimizers.Adam(
learning_rate=config["actor_lr"])
policy._critic_optimizer = tf.keras.optimizers.Adam(
learning_rate=config["critic_lr"])
else:
policy._actor_optimizer = tf1.train.AdamOptimizer(
learning_rate=config["actor_lr"])
policy._critic_optimizer = tf1.train.AdamOptimizer(
learning_rate=config["critic_lr"])
return None
# TFPolicy.__init__(
# self,
# observation_space,
# action_space,
# self.config,
# self.sess,
# #obs_input=self.cur_observations,
# sampled_action=self.output_actions,
# loss=self.actor_loss + self.critic_loss,
# loss_inputs=self.loss_inputs,
# update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops,
# explore=explore,
# dist_inputs=self._distribution_inputs,
# dist_class=Deterministic,
# timestep=timestep)
def build_apply_op(policy, optimizer, grads_and_vars):
# For policy gradient, update policy net one time v.s.
@ -299,34 +266,44 @@ def build_apply_op(policy, optimizer, grads_and_vars):
critic_op = policy._critic_optimizer.apply_gradients(
policy._critic_grads_and_vars)
# Increment global step & apply ops.
with tf1.control_dependencies([tf1.assign_add(policy.global_step, 1)]):
return tf.group(actor_op, critic_op)
if tfv == 2 and policy.config["framework"] == "tfe":
policy.global_step.assign_add(1)
return tf.no_op()
else:
with tf1.control_dependencies([tf1.assign_add(policy.global_step, 1)]):
return tf.group(actor_op, critic_op)
def gradients_fn(policy, optimizer, loss):
if policy.config["grad_norm_clipping"] is not None:
actor_grads_and_vars = minimize_and_clip(
policy._actor_optimizer,
policy.actor_loss,
var_list=policy.model.policy_variables(),
clip_val=policy.config["grad_norm_clipping"])
critic_grads_and_vars = minimize_and_clip(
policy._critic_optimizer,
policy.critic_loss,
var_list=policy.model.q_variables(),
clip_val=policy.config["grad_norm_clipping"])
if policy.config["framework"] == "tfe":
tape = optimizer.tape
pol_weights = policy.model.policy_variables()
actor_grads_and_vars = list(zip(tape.gradient(
policy.actor_loss, pol_weights), pol_weights))
q_weights = policy.model.q_variables()
critic_grads_and_vars = list(zip(tape.gradient(
policy.critic_loss, q_weights), q_weights))
else:
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
policy.actor_loss, var_list=policy.model.policy_variables())
critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
policy.critic_loss, var_list=policy.model.q_variables())
# Save these for later use in build_apply_op.
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
if g is not None]
policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
if g is not None]
# Clip if necessary.
if policy.config["grad_clip"]:
clip_func = tf.clip_by_norm
else:
clip_func = tf.identity
# Save grads and vars for later use in `build_apply_op`.
policy._actor_grads_and_vars = [
(clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None]
policy._critic_grads_and_vars = [
(clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None]
grads_and_vars = policy._actor_grads_and_vars + \
policy._critic_grads_and_vars
return grads_and_vars
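To make the eager branch above easier to follow, here is a standalone sketch (not part of the diff) of the same GradientTape pattern: separate gradients per loss and variable group, optional clip-by-norm, and per-optimizer apply_gradients. In RLlib the tape comes from the eager policy as optimizer.tape; here one is created directly, and all variables, losses, and the clip value are placeholders.
import tensorflow as tf

# Toy "actor" and "critic" variable groups standing in for
# model.policy_variables() and model.q_variables().
actor_vars = [tf.Variable([1.0, 2.0]), tf.Variable(0.5)]
critic_vars = [tf.Variable([[1.0], [2.0]])]

actor_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Persistent tape so gradients of two different losses can be taken from it.
with tf.GradientTape(persistent=True) as tape:
    actor_loss = tf.reduce_sum(tf.square(actor_vars[0])) + actor_vars[1] ** 2
    critic_loss = tf.reduce_sum(tf.square(critic_vars[0]))

clip_val = 40.0  # Placeholder for config["grad_clip"].
actor_grads_and_vars = [
    (tf.clip_by_norm(g, clip_val), v)
    for g, v in zip(tape.gradient(actor_loss, actor_vars), actor_vars)
    if g is not None]
critic_grads_and_vars = [
    (tf.clip_by_norm(g, clip_val), v)
    for g, v in zip(tape.gradient(critic_loss, critic_vars), critic_vars)
    if g is not None]
del tape  # Release the persistent tape.

actor_optimizer.apply_gradients(actor_grads_and_vars)
critic_optimizer.apply_gradients(critic_grads_and_vars)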
@ -341,7 +318,10 @@ def build_ddpg_stats(policy, batch):
def before_init_fn(policy, obs_space, action_space, config):
# Create global step for counting the number of update operations.
policy.global_step = tf1.train.get_or_create_global_step()
if tfv == 2 and config["framework"] == "tfe":
policy.global_step = get_variable(0, tf_name="global_step")
else:
policy.global_step = tf1.train.get_or_create_global_step()
class ComputeTDErrorMixin:

View file

@ -24,7 +24,7 @@ class TestApexDDPG(unittest.TestCase):
config["learning_starts"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
num_iterations = 1
for _ in framework_iterator(config, ("torch", "tf")):
for _ in framework_iterator(config, frameworks=("tf", "torch")):
plain_config = config.copy()
trainer = apex_ddpg.ApexDDPGTrainer(
config=plain_config, env="Pendulum-v0")

View file

@ -35,15 +35,16 @@ class TestDDPG(unittest.TestCase):
config["learning_starts"] = 0
config["exploration_config"]["random_timesteps"] = 100
num_iterations = 2
num_iterations = 1
# Test against all frameworks.
for _ in framework_iterator(config, ("tf", "torch")):
for _ in framework_iterator(config):
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
for i in range(num_iterations):
results = trainer.train()
print(results)
check_compute_single_action(trainer)
trainer.stop()
def test_ddpg_exploration_and_with_random_prerun(self):
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""
@ -52,7 +53,7 @@ class TestDDPG(unittest.TestCase):
obs = np.array([0.0, 0.1, -0.1])
# Test against all frameworks.
for _ in framework_iterator(core_config, ("torch", "tf")):
for _ in framework_iterator(core_config):
config = core_config.copy()
# Default OUNoise setup.
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
@ -66,6 +67,7 @@ class TestDDPG(unittest.TestCase):
for _ in range(50):
actions.append(trainer.compute_action(obs))
check(np.std(actions), 0.0, false=True)
trainer.stop()
# Check randomness at beginning.
config["exploration_config"] = {
@ -95,6 +97,7 @@ class TestDDPG(unittest.TestCase):
for _ in range(50):
a = trainer.compute_action(obs, explore=False)
check(a, deterministic_action)
trainer.stop()
def test_ddpg_loss_function(self):
"""Tests DDPG loss function results across all frameworks."""

View file

@ -16,13 +16,14 @@ class TestTD3(unittest.TestCase):
config["num_workers"] = 0 # Run locally.
# Test against all frameworks.
for _ in framework_iterator(config, frameworks=["tf"]):
for _ in framework_iterator(config):
trainer = td3.TD3Trainer(config=config, env="Pendulum-v0")
num_iterations = 2
num_iterations = 1
for i in range(num_iterations):
results = trainer.train()
print(results)
check_compute_single_action(trainer)
trainer.stop()
def test_td3_exploration_and_with_random_prerun(self):
"""Tests TD3's Exploration (w/ random actions for n timesteps)."""
@ -31,7 +32,7 @@ class TestTD3(unittest.TestCase):
obs = np.array([0.0, 0.1, -0.1])
# Test against all frameworks.
for _ in framework_iterator(config, frameworks="tf"):
for _ in framework_iterator(config):
lcl_config = config.copy()
# Default GaussianNoise setup.
trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v0")

View file

@ -146,6 +146,9 @@ def validate_config(config):
deprecation_warning("grad_norm_clipping", "grad_clip")
config["grad_clip"] = config.pop("grad_norm_clipping")
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")
# Use same keys as for standard Trainer "model" config.
for model in ["Q_model", "policy_model"]:
if config[model].get("hidden_activation", DEPRECATED_VALUE) != \

View file

@ -15,7 +15,6 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.tf_ops import minimize_and_clip
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@ -227,15 +226,14 @@ def sac_actor_critic_loss(policy, model, _, train_batch):
td_error = base_td_error
critic_loss = [
tf1.losses.mean_squared_error(
labels=q_t_selected_target, predictions=q_t_selected, weights=0.5)
0.5 * tf.keras.losses.MSE(
y_true=q_t_selected_target, y_pred=q_t_selected)
]
if policy.config["twin_q"]:
critic_loss.append(
tf1.losses.mean_squared_error(
labels=q_t_selected_target,
predictions=twin_q_t_selected,
weights=0.5))
0.5 * tf.keras.losses.MSE(
y_true=q_t_selected_target,
y_pred=twin_q_t_selected))
# Alpha- and actor losses.
# Note: In the papers, alpha is used directly, here we take the log.
@ -277,63 +275,64 @@ def sac_actor_critic_loss(policy, model, _, train_batch):
return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
def gradients(policy, optimizer, loss):
if policy.config["grad_clip"]:
actor_grads_and_vars = minimize_and_clip(
optimizer, # isn't optimizer not well defined here (which one)?
policy.actor_loss,
var_list=policy.model.policy_variables(),
clip_val=policy.config["grad_clip"])
def gradients_fn(policy, optimizer, loss):
# Eager: Use GradientTape.
if policy.config["framework"] == "tfe":
tape = optimizer.tape
pol_weights = policy.model.policy_variables()
actor_grads_and_vars = list(zip(tape.gradient(
policy.actor_loss, pol_weights), pol_weights))
q_weights = policy.model.q_variables()
if policy.config["twin_q"]:
q_variables = policy.model.q_variables()
half_cutoff = len(q_variables) // 2
critic_grads_and_vars = []
critic_grads_and_vars += minimize_and_clip(
optimizer,
policy.critic_loss[0],
var_list=q_variables[:half_cutoff],
clip_val=policy.config["grad_clip"])
critic_grads_and_vars += minimize_and_clip(
optimizer,
policy.critic_loss[1],
var_list=q_variables[half_cutoff:],
clip_val=policy.config["grad_clip"])
half_cutoff = len(q_weights) // 2
grads_1 = tape.gradient(
policy.critic_loss[0], q_weights[:half_cutoff])
grads_2 = tape.gradient(
policy.critic_loss[1], q_weights[half_cutoff:])
critic_grads_and_vars = \
list(zip(grads_1, q_weights[:half_cutoff])) + \
list(zip(grads_2, q_weights[half_cutoff:]))
else:
critic_grads_and_vars = minimize_and_clip(
optimizer,
policy.critic_loss[0],
var_list=policy.model.q_variables(),
clip_val=policy.config["grad_clip"])
alpha_grads_and_vars = minimize_and_clip(
optimizer,
policy.alpha_loss,
var_list=[policy.model.log_alpha],
clip_val=policy.config["grad_clip"])
critic_grads_and_vars = list(zip(tape.gradient(
policy.critic_loss[0], q_weights), q_weights))
alpha_vars = [policy.model.log_alpha]
alpha_grads_and_vars = list(zip(tape.gradient(
policy.alpha_loss, alpha_vars), alpha_vars))
# Tf1.x: Use optimizer.compute_gradients()
else:
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
policy.actor_loss, var_list=policy.model.policy_variables())
q_weights = policy.model.q_variables()
if policy.config["twin_q"]:
q_variables = policy.model.q_variables()
half_cutoff = len(q_variables) // 2
half_cutoff = len(q_weights) // 2
base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
critic_grads_and_vars = base_q_optimizer.compute_gradients(
policy.critic_loss[0], var_list=q_variables[:half_cutoff]
policy.critic_loss[0], var_list=q_weights[:half_cutoff]
) + twin_q_optimizer.compute_gradients(
policy.critic_loss[1], var_list=q_variables[half_cutoff:])
policy.critic_loss[1], var_list=q_weights[half_cutoff:])
else:
critic_grads_and_vars = policy._critic_optimizer[
0].compute_gradients(
policy.critic_loss[0], var_list=policy.model.q_variables())
policy.critic_loss[0], var_list=q_weights)
alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
policy.alpha_loss, var_list=[policy.model.log_alpha])
# save these for later use in build_apply_op
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
if g is not None]
policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
if g is not None]
policy._alpha_grads_and_vars = [(g, v) for (g, v) in alpha_grads_and_vars
if g is not None]
# Clip if necessary.
if policy.config["grad_clip"]:
clip_func = tf.clip_by_norm
else:
clip_func = tf.identity
# Save grads and vars for later use in `build_apply_op`.
policy._actor_grads_and_vars = [
(clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None]
policy._critic_grads_and_vars = [
(clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None]
policy._alpha_grads_and_vars = [
(clip_func(g), v) for (g, v) in alpha_grads_and_vars if g is not None]
grads_and_vars = (
policy._actor_grads_and_vars + policy._critic_grads_and_vars +
policy._alpha_grads_and_vars)
@ -431,7 +430,7 @@ SACTFPolicy = build_tf_policy(
action_distribution_fn=get_distribution_inputs_and_class,
loss_fn=sac_actor_critic_loss,
stats_fn=stats,
gradients_fn=gradients,
gradients_fn=gradients_fn,
apply_gradients_fn=apply_gradients,
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
mixins=[

View file

@ -54,7 +54,7 @@ class TestSAC(unittest.TestCase):
config["learning_starts"] = 0
config["prioritized_replay"] = True
num_iterations = 1
for _ in framework_iterator(config, ("tf", "torch")):
for _ in framework_iterator(config):
# Test for different env types (discrete w/ and w/o image, + cont).
for env in [
"Pendulum-v0", "MsPacmanNoFrameskip-v4", "CartPole-v0"

View file

@ -3,6 +3,7 @@ from .multi_agent_env import MultiAgentEnv
class PettingZooEnv(MultiAgentEnv):
"""An interface to the PettingZoo MARL environment library.
See: https://github.com/PettingZoo-Team/PettingZoo
Inherits from MultiAgentEnv and exposes a given AEC
@ -31,33 +32,32 @@ class PettingZooEnv(MultiAgentEnv):
>>> env = PettingZooEnv(prison_v0.env())
>>> obs = env.reset()
>>> print(obs)
{
"0": [110, 119],
"1": [105, 102],
"2": [99, 95],
}
{
"0": [110, 119],
"1": [105, 102],
"2": [99, 95],
}
>>> obs, rewards, dones, infos = env.step(
action_dict={
"0": 1, "1": 0, "2": 2,
})
>>> print(rewards)
{
"0": 0,
"1": 1,
"2": 0,
}
{
"0": 0,
"1": 1,
"2": 0,
}
>>> print(dones)
{
"0": False, # agent 0 is still running
"1": True, # agent 1 is done
"__all__": False, # the env is not done
}
{
"0": False, # agent 0 is still running
"1": True, # agent 1 is done
"__all__": False, # the env is not done
}
>>> print(infos)
{
"0": {}, # info for agent 0
"1": {}, # info for agent 1
}
{
"0": {}, # info for agent 0
"1": {}, # info for agent 1
}
"""
def __init__(self, env):

View file

@ -103,19 +103,18 @@ class Unity3DEnv(MultiAgentEnv):
Args:
action_dict (dict): Multi-agent action dict with:
keys=agent identifier consisting of
[MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
[Agent index, a unique MLAgent-assigned index per single
agent]
[MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
[Agent index, a unique MLAgent-assigned index per single agent]
Returns:
tuple:
obs: Multi-agent observation dict.
- obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
rewards: Rewards dict matching `obs`.
dones: Done dict with only an __all__ multi-agent entry in it.
__all__=True, if episode is done for all agents.
infos: An (empty) info dict.
- rewards: Rewards dict matching `obs`.
- dones: Done dict with only an __all__ multi-agent entry in
it. __all__=True, if episode is done for all agents.
- infos: An (empty) info dict.
"""
# Set only the required actions (from the DecisionSteps) in Unity3D.

View file

@ -325,16 +325,15 @@ def build_eager_tf_policy(name,
self._is_training = False
self._state_in = state_batches
if tf.executing_eagerly():
n = len(obs_batch)
else:
n = obs_batch.shape[0]
seq_lens = tf.ones(n, dtype=tf.int32)
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
input_dict = {
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
"is_training": tf.constant(False),
}
n = input_dict[SampleBatch.CUR_OBS].shape[0]
seq_lens = tf.ones(n, dtype=tf.int32)
if obs_include_prev_action_reward:
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \

View file

@ -154,14 +154,14 @@ class Policy(metaclass=ABCMeta):
timestep (Optional[int]): The current (sampling) time step.
Keyword Args:
kwargs: forward compatibility placeholder
kwargs: Forward compatibility.
Returns:
Tuple:
actions (TensorType): Single action.
state_outs (List[TensorType]): List of RNN state outputs,
- actions (TensorType): Single action.
- state_outs (List[TensorType]): List of RNN state outputs,
if any.
info (dict): Dictionary of extra features, if any.
- info (dict): Dictionary of extra features, if any.
"""
prev_action_batch = None
prev_reward_batch = None
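For reference, a minimal usage sketch (not part of the diff) of compute_single_action, whose Returns section is reformatted above; the PGTrainer/CartPole setup is an arbitrary assumption kept small for illustration:
import gym
import ray
from ray.rllib.agents.pg import PGTrainer

ray.init(ignore_reinit_error=True)
trainer = PGTrainer(config={"framework": "tf", "num_workers": 0}, env="CartPole-v0")
policy = trainer.get_policy()

obs = gym.make("CartPole-v0").reset()
# Returns the tuple documented above: single action, RNN state outs, extra-feature dict.
action, state_outs, info = policy.compute_single_action(obs)
print(action, state_outs, info)
trainer.stop()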

View file

@ -38,9 +38,6 @@ def do_test_log_likelihood(run,
# Test against all frameworks.
for fw in framework_iterator(config):
if run in [sac.SACTrainer] and fw == "tfe":
continue
trainer = run(config=config, env=env)
policy = trainer.get_policy()
@ -171,7 +168,7 @@ class TestComputeLogLikelihood(unittest.TestCase):
config,
prev_a,
continuous=True,
layer_key=("sequential/action", (0, 2),
layer_key=("sequential/action", (2, 4),
("action_model.action_0.", "action_model.action_out.")),
logp_func=logp_func)

View file

@ -5,8 +5,9 @@ from ray import tune
from ray.rllib.agents.registry import get_agent_class
def check_support(alg, config, test_trace=True):
def check_support(alg, config, test_eager=False, test_trace=True):
config["framework"] = "tfe"
config["log_level"] = "ERROR"
# Test both continuous and discrete actions.
for cont in [True, False]:
if cont and alg in ["DQN", "APEX", "SimpleQ"]:
@ -14,46 +15,31 @@ def check_support(alg, config, test_trace=True):
elif not cont and alg in ["DDPG", "APEX_DDPG", "TD3"]:
continue
print("run={} cont. actions={}".format(alg, cont))
if cont:
config["env"] = "Pendulum-v0"
else:
config["env"] = "CartPole-v0"
a = get_agent_class(alg)
config["log_level"] = "ERROR"
config["eager_tracing"] = False
tune.run(a, config=config, stop={"training_iteration": 1})
if test_eager:
print("tf-eager: alg={} cont.act={}".format(alg, cont))
config["eager_tracing"] = False
tune.run(
a, config=config, stop={"training_iteration": 1}, verbose=1)
if test_trace:
config["eager_tracing"] = True
tune.run(a, config=config, stop={"training_iteration": 1})
print("tf-eager-tracing: alg={} cont.act={}".format(alg, cont))
tune.run(
a, config=config, stop={"training_iteration": 1}, verbose=1)
class TestEagerSupport(unittest.TestCase):
class TestEagerSupportPG(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, local_mode=True)
ray.init(num_cpus=4)
def tearDown(self):
ray.shutdown()
def test_simple_q(self):
check_support("SimpleQ", {"num_workers": 0, "learning_starts": 0})
def test_dqn(self):
check_support("DQN", {"num_workers": 0, "learning_starts": 0})
# TODO(sven): Add these once DDPG supports eager.
# def test_ddpg(self):
# check_support("DDPG", {"num_workers": 0})
# def test_apex_ddpg(self):
# check_support("APEX_DDPG", {"num_workers": 1})
# def test_td3(self):
# check_support("TD3", {"num_workers": 0})
def test_a2c(self):
check_support("A2C", {"num_workers": 0})
@ -70,7 +56,31 @@ class TestEagerSupport(unittest.TestCase):
check_support("APPO", {"num_workers": 1, "num_gpus": 0})
def test_impala(self):
check_support("IMPALA", {"num_workers": 1, "num_gpus": 0})
check_support(
"IMPALA", {"num_workers": 1, "num_gpus": 0}, test_eager=True)
class TestEagerSupportOffPolicy(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4)
def tearDown(self):
ray.shutdown()
def test_simple_q(self):
check_support("SimpleQ", {"num_workers": 0, "learning_starts": 0})
def test_dqn(self):
check_support("DQN", {"num_workers": 0, "learning_starts": 0})
def test_ddpg(self):
check_support("DDPG", {"num_workers": 0})
# def test_apex_ddpg(self):
# check_support("APEX_DDPG", {"num_workers": 1})
def test_td3(self):
check_support("TD3", {"num_workers": 0})
def test_apex_dqn(self):
check_support(
@ -85,12 +95,15 @@ class TestEagerSupport(unittest.TestCase):
},
})
# TODO(sven): Add this once SAC supports eager.
# def test_sac(self):
# check_support("SAC", {"num_workers": 0, "learning_starts": 0})
def test_sac(self):
check_support("SAC", {"num_workers": 0, "learning_starts": 0})
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
# One can specify the specific TestCase class to run.
# None for all unittest.TestCase classes in this file.
class_ = sys.argv[1] if len(sys.argv) > 1 else None
sys.exit(pytest.main(
["-v", __file__ + ("" if class_ is None else "::" + class_)]))

View file

@ -14,7 +14,9 @@ def check_support_multiagent(alg, config):
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": 2}))
config["log_level"] = "ERROR"
for _ in framework_iterator(config, frameworks=("torch", "tf")):
for fw in framework_iterator(config):
if fw == "tfe" and alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]:
continue
if alg in ["DDPG", "APEX_DDPG", "SAC"]:
a = get_agent_class(alg)(
config=config, env="multi_agent_mountaincar")

View file

@ -88,11 +88,7 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False):
assert isinstance(a.get_policy().model, FCNetV2)
if train:
a.train()
try:
a.stop()
except Exception as e:
print("Ignoring error stopping agent", e)
pass
a.stop()
print(stat)
frameworks = ("torch", "tf")
@ -108,7 +104,7 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False):
_do_check(alg, config, a_name, o_name)
class TestSupportedSpaces(unittest.TestCase):
class TestSupportedSpacesPG(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
@ -125,40 +121,6 @@ class TestSupportedSpaces(unittest.TestCase):
check_support("APPO", {"num_gpus": 0, "vtrace": False}, train=False)
check_support("APPO", {"num_gpus": 0, "vtrace": True})
def test_ars(self):
check_support(
"ARS", {
"num_workers": 1,
"noise_size": 1500000,
"num_rollouts": 1,
"rollouts_used": 1
})
def test_ddpg(self):
check_support(
"DDPG", {
"exploration_config": {
"ou_base_scale": 100.0
},
"timesteps_per_iteration": 1,
"buffer_size": 1000,
"use_state_preprocessor": True,
},
check_bounds=True)
def test_dqn(self):
config = {"timesteps_per_iteration": 1, "buffer_size": 1000}
check_support("DQN", config, tfe=True)
def test_es(self):
check_support(
"ES", {
"num_workers": 1,
"noise_size": 1500000,
"episodes_per_batch": 1,
"train_batch_size": 1
})
def test_impala(self):
check_support("IMPALA", {"num_gpus": 0})
@ -176,21 +138,70 @@ class TestSupportedSpaces(unittest.TestCase):
config = {"num_workers": 1, "optimizer": {}}
check_support("PG", config, train=False, check_bounds=True, tfe=True)
class TestSupportedSpacesOffPolicy(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_ddpg(self):
check_support(
"DDPG", {
"exploration_config": {
"ou_base_scale": 100.0
},
"timesteps_per_iteration": 1,
"buffer_size": 1000,
"use_state_preprocessor": True,
},
check_bounds=True)
def test_dqn(self):
config = {"timesteps_per_iteration": 1, "buffer_size": 1000}
check_support("DQN", config, tfe=True)
def test_sac(self):
check_support("SAC", {"buffer_size": 1000}, check_bounds=True)
class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_ars(self):
check_support(
"ARS", {
"num_workers": 1,
"noise_size": 1500000,
"num_rollouts": 1,
"rollouts_used": 1
})
def test_es(self):
check_support(
"ES", {
"num_workers": 1,
"noise_size": 1500000,
"episodes_per_batch": 1,
"train_batch_size": 1
})
if __name__ == "__main__":
import pytest
import sys
if len(sys.argv) > 1 and sys.argv[1] == "--smoke":
ACTION_SPACES_TO_TEST = {
"discrete": Discrete(5),
}
OBSERVATION_SPACES_TO_TEST = {
"vector": Box(0.0, 1.0, (5, ), dtype=np.float32),
"atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32),
}
sys.exit(pytest.main(["-v", __file__]))
# One can specify the specific TestCase class to run.
# None for all unittest.TestCase classes in this file.
class_ = sys.argv[1] if len(sys.argv) > 1 else None
sys.exit(pytest.main(
["-v", __file__ + ("" if class_ is None else "::" + class_)]))

View file

@ -104,7 +104,7 @@ class GaussianNoise(Exploration):
self.random_exploration.get_tf_exploration_action_op(
action_dist, explore)
stochastic_actions = tf.cond(
pred=ts <= self.random_timesteps,
pred=tf.convert_to_tensor(ts <= self.random_timesteps),
true_fn=lambda: random_actions,
false_fn=lambda: tf.clip_by_value(
deterministic_actions + gaussian_sample,

View file

@ -110,7 +110,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
self.random_exploration.get_tf_exploration_action_op(
action_dist, explore)
exploration_actions = tf.cond(
pred=ts <= self.random_timesteps,
pred=tf.convert_to_tensor(ts <= self.random_timesteps),
true_fn=lambda: random_actions,
false_fn=lambda: stochastic_actions)

View file

@ -28,11 +28,6 @@ def do_test_explorations(run,
# Test all frameworks.
for fw in framework_iterator(core_config):
if fw == "tfe" and run in [
ddpg.DDPGTrainer, sac.SACTrainer, td3.TD3Trainer
]:
continue
print("Agent={}".format(run))
# Test for both the default Agent's exploration AND the `Random`

View file

@ -12,8 +12,7 @@ class TestParameterNoise(unittest.TestCase):
ddpg.DDPGTrainer,
ddpg.DEFAULT_CONFIG,
"Pendulum-v0", {},
np.array([1.0, 0.0, -1.0]),
fws="tf")
np.array([1.0, 0.0, -1.0]))
def test_dqn_parameter_noise(self):
self.do_test_parameter_noise_exploration(
@ -23,18 +22,16 @@ class TestParameterNoise(unittest.TestCase):
"is_slippery": False,
"map_name": "4x4"
},
np.array(0),
fws=("tf", "tfe"))
np.array(0))
def do_test_parameter_noise_exploration(self, trainer_cls, config, env,
env_config, obs, fws):
def do_test_parameter_noise_exploration(
self, trainer_cls, config, env, env_config, obs):
"""Tests, whether an Agent works with ParameterNoise."""
core_config = config.copy()
core_config["num_workers"] = 0 # Run locally.
core_config["env_config"] = env_config
for fw in framework_iterator(core_config, fws):
for fw in framework_iterator(core_config):
config = core_config.copy()
# Algo with ParameterNoise exploration (config["explore"]=True).
@ -44,13 +41,15 @@ class TestParameterNoise(unittest.TestCase):
trainer = trainer_cls(config=config, env=env)
policy = trainer.get_policy()
pol_sess = getattr(policy, "_sess", None)
self.assertFalse(policy.exploration.weights_are_currently_noisy)
noise_before = self._get_current_noise(policy, fw)
check(noise_before, 0.0)
initial_weights = self._get_current_weight(policy, fw)
# Pseudo-start an episode and compare the weights before and after.
policy.exploration.on_episode_start(policy, tf_sess=policy._sess)
policy.exploration.on_episode_start(policy, tf_sess=pol_sess)
self.assertFalse(policy.exploration.weights_are_currently_noisy)
noise_after_ep_start = self._get_current_noise(policy, fw)
weights_after_ep_start = self._get_current_weight(policy, fw)
@ -91,7 +90,7 @@ class TestParameterNoise(unittest.TestCase):
# Pseudo-end the episode and compare weights again.
# Make sure they are the original ones.
policy.exploration.on_episode_end(policy, tf_sess=policy._sess)
policy.exploration.on_episode_end(policy, tf_sess=pol_sess)
weights_after_ep_end = self._get_current_weight(policy, fw)
check(current_weight - noise, weights_after_ep_end, decimals=5)
@ -111,7 +110,7 @@ class TestParameterNoise(unittest.TestCase):
# Pseudo-start an episode and compare the weights before and after
# (they should be the same).
policy.exploration.on_episode_start(policy, tf_sess=policy._sess)
policy.exploration.on_episode_start(policy, tf_sess=pol_sess)
self.assertFalse(policy.exploration.weights_are_currently_noisy)
# Should be the same, as we don't do anything at the beginning of
@ -136,7 +135,7 @@ class TestParameterNoise(unittest.TestCase):
# Pseudo-end the episode and compare weights again.
# Make sure they are the original ones (no noise permanently
# applied throughout the episode).
policy.exploration.on_episode_end(policy, tf_sess=policy._sess)
policy.exploration.on_episode_end(policy, tf_sess=pol_sess)
weights_after_episode_end = self._get_current_weight(policy, fw)
check(initial_weights, weights_after_episode_end)
# Noise should still be the same (re-sampling only happens at
@ -170,7 +169,7 @@ class TestParameterNoise(unittest.TestCase):
# the same action for the same input (parameter noise is
# deterministic).
policy = trainer.get_policy()
policy.exploration.on_episode_start(policy, tf_sess=policy._sess)
policy.exploration.on_episode_start(policy, tf_sess=pol_sess)
a_ = trainer.compute_action(obs)
for _ in range(10):
a = trainer.compute_action(obs, explore=True)