[RLlib] Add testing framework_iterator. (#7852)

* Add testing framework_iterator.

* LINT.

* WIP.

* Fix and LINT.

* LINT fix.
Sven Mika 2020-04-03 21:24:25 +02:00 committed by GitHub
parent bb6c675231
commit 1d4823c0ec
15 changed files with 323 additions and 303 deletions


@ -3,7 +3,7 @@ import unittest
import ray.rllib.agents.ddpg as ddpg
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
@ -15,11 +15,7 @@ class TestDDPG(unittest.TestCase):
config["num_workers"] = 0 # Run locally.
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
if fw != "tf":
continue
config["eager"] = True if fw == "eager" else False
config["use_pytorch"] = True if fw == "torch" else False
for _ in framework_iterator(config, "tf"):
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
num_iterations = 2
for i in range(num_iterations):
@ -33,12 +29,7 @@ class TestDDPG(unittest.TestCase):
obs = np.array([0.0, 0.1, -0.1])
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
if fw != "tf":
continue
config["eager"] = True if fw == "eager" else False
config["use_pytorch"] = True if fw == "torch" else False
for _ in framework_iterator(config, "tf"):
# Default OUNoise setup.
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
# Setting explore=False should always return the same action.


@ -3,7 +3,7 @@ import unittest
import ray.rllib.agents.ddpg.td3 as td3
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
@ -15,11 +15,7 @@ class TestTD3(unittest.TestCase):
config["num_workers"] = 0 # Run locally.
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
if fw != "tf":
continue
config["eager"] = True if fw == "eager" else False
config["use_pytorch"] = True if fw == "torch" else False
for _ in framework_iterator(config, frameworks=["tf"]):
trainer = td3.TD3Trainer(config=config, env="Pendulum-v0")
num_iterations = 2
for i in range(num_iterations):
@ -33,12 +29,7 @@ class TestTD3(unittest.TestCase):
obs = np.array([0.0, 0.1, -0.1])
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
if fw != "tf":
continue
config["eager"] = True if fw == "eager" else False
config["use_pytorch"] = True if fw == "torch" else False
for _ in framework_iterator(config, frameworks="tf"):
# Default GaussianNoise setup.
trainer = td3.TD3Trainer(config=config, env="Pendulum-v0")
# Setting explore=False should always return the same action.


@ -77,7 +77,8 @@ class QLoss:
# priority is robust and insensitive to `prioritized_replay_alpha`
self.td_error = tf.nn.softmax_cross_entropy_with_logits(
labels=m, logits=q_logits_t_selected)
self.loss = tf.reduce_mean(self.td_error * importance_weights)
self.loss = tf.reduce_mean(
self.td_error * tf.cast(importance_weights, tf.float32))
self.stats = {
# TODO: better Q stats for dist dqn
"mean_td_error": tf.reduce_mean(self.td_error),


@ -1,10 +1,9 @@
import numpy as np
from tensorflow.python.eager.context import eager_mode
import unittest
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
@ -14,42 +13,28 @@ class TestDQN(unittest.TestCase):
"""Test whether a DQNTrainer can be built with both frameworks."""
config = dqn.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
num_iterations = 2
for _ in framework_iterator(config, frameworks=["tf", "eager"]):
# Rainbow.
rainbow_config = config.copy()
rainbow_config["eager"] = False
rainbow_config["num_atoms"] = 10
rainbow_config["noisy"] = True
rainbow_config["double_q"] = True
rainbow_config["dueling"] = True
rainbow_config["n_step"] = 5
trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0")
num_iterations = 2
for i in range(num_iterations):
results = trainer.train()
print(results)
# tf.
tf_config = config.copy()
tf_config["eager"] = False
trainer = dqn.DQNTrainer(config=tf_config, env="CartPole-v0")
num_iterations = 1
# double-dueling DQN.
plain_config = config.copy()
trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
for i in range(num_iterations):
results = trainer.train()
print(results)
# Eager.
eager_config = config.copy()
eager_config["eager"] = True
eager_ctx = eager_mode()
eager_ctx.__enter__()
trainer = dqn.DQNTrainer(config=eager_config, env="CartPole-v0")
num_iterations = 1
for i in range(num_iterations):
results = trainer.train()
print(results)
eager_ctx.__exit__(None, None, None)
def test_dqn_exploration_and_soft_q_config(self):
"""Tests, whether a DQN Agent outputs exploration/softmaxed actions."""
config = dqn.DEFAULT_CONFIG.copy()
@ -58,22 +43,7 @@ class TestDQN(unittest.TestCase):
obs = np.array(0)
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
if fw == "torch":
continue
print("framework={}".format(fw))
eager_mode_ctx = None
if fw == "tf":
assert not tf.executing_eagerly()
else:
eager_mode_ctx = eager_mode()
eager_mode_ctx.__enter__()
config["eager"] = fw == "eager"
config["use_pytorch"] = fw == "torch"
for _ in framework_iterator(config, ["tf", "eager"]):
# Default EpsilonGreedy setup.
trainer = dqn.DQNTrainer(config=config, env="FrozenLake-v0")
# Setting explore=False should always return the same action.


@ -14,7 +14,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import fc
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
@ -32,16 +32,9 @@ class TestPPO(unittest.TestCase):
"""Test whether a PPOTrainer can be built with both frameworks."""
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
# tf.
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
num_iterations = 2
for i in range(num_iterations):
trainer.train()
# Torch.
config["use_pytorch"] = True
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
for i in range(num_iterations):
trainer.train()
@ -81,10 +74,7 @@ class TestPPO(unittest.TestCase):
obs = np.array(0)
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
config["eager"] = True if fw == "eager" else False
config["use_pytorch"] = True if fw == "torch" else False
for fw in framework_iterator(config):
# Default Agent should be setup with StochasticSampling.
trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v0")
# explore=False, always expect the same (deterministic) action.
@ -131,7 +121,10 @@ class TestPPO(unittest.TestCase):
[0.9, 1.0, 1.1, 1.2]],
dtype=np.float32),
SampleBatch.ACTIONS: np.array([0, 1, 1]),
SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
SampleBatch.REWARDS: np.array([1.0, -1.0, .5], dtype=np.float32),
SampleBatch.PREV_REWARDS: np.array(
[1.0, -1.0, .5], dtype=np.float32),
SampleBatch.DONES: np.array([False, False, True]),
SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
SampleBatch.ACTION_DIST_INPUTS: np.array(
@ -140,11 +133,8 @@ class TestPPO(unittest.TestCase):
[-0.5, -0.1, -0.2], dtype=np.float32),
}
for fw in ["tf", "torch"]:
print("framework={}".format(fw))
config["use_pytorch"] = fw == "torch"
config["eager"] = fw == "tf"
for fw, sess in framework_iterator(
config, frameworks=["eager", "tf", "torch"], session=True):
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
@ -152,7 +142,7 @@ class TestPPO(unittest.TestCase):
# to train_batch dict.
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5]
if fw == "tf":
if fw == "tf" or fw == "eager":
train_batch = postprocess_ppo_gae_tf(policy, train_batch)
else:
train_batch = postprocess_ppo_gae_torch(policy, train_batch)
@ -162,17 +152,18 @@ class TestPPO(unittest.TestCase):
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])
# Calculate actual PPO loss (results are stored in policy.loss_obj)
# for tf.
if fw == "tf":
# Calculate actual PPO loss.
if fw == "eager":
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
train_batch)
else:
elif fw == "torch":
ppo_surrogate_loss_torch(policy, policy.model,
TorchCategorical, train_batch)
vars = policy.model.variables() if fw == "tf" else \
vars = policy.model.variables() if fw != "torch" else \
list(policy.model.parameters())
if fw == "tf":
vars = policy.get_session().run(vars)
expected_shared_out = fc(train_batch[SampleBatch.CUR_OBS], vars[0],
vars[1])
expected_logits = fc(expected_shared_out, vars[2], vars[3])
@ -181,18 +172,42 @@ class TestPPO(unittest.TestCase):
kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model,
Categorical if fw == "tf" else TorchCategorical,
Categorical if fw != "torch" else TorchCategorical,
train_batch,
expected_logits, expected_value_outs
expected_logits, expected_value_outs,
sess=sess
)
if sess:
policy_sess = policy.get_session()
k, e, pl, v, tl = policy_sess.run(
[
policy.loss_obj.mean_kl, policy.loss_obj.mean_entropy,
policy.loss_obj.mean_policy_loss,
policy.loss_obj.mean_vf_loss, policy.loss_obj.loss
],
feed_dict=policy._get_loss_inputs_dict(
train_batch, shuffle=False))
check(k, kl)
check(e, entropy)
check(pl, np.mean(-pg_loss))
check(v, np.mean(vf_loss), decimals=4)
check(tl, overall_loss, decimals=4)
else:
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(
policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)
def _ppo_loss_helper(self, policy, model, dist_class, train_batch, logits,
vf_outs):
def _ppo_loss_helper(self,
policy,
model,
dist_class,
train_batch,
logits,
vf_outs,
sess=None):
"""
Calculates the expected PPO loss (components) given Policy,
Model, distribution, some batch, logits & vf outputs, using numpy.
@ -210,12 +225,20 @@ class TestPPO(unittest.TestCase):
# Entropy-loss component.
entropy = np.mean(dist.entropy().detach().numpy())
else:
if sess:
expected_logp = sess.run(expected_logp)
expected_rho = np.exp(expected_logp -
train_batch[SampleBatch.ACTION_LOGP])
# KL(prev vs current action dist)-loss component.
kl = np.mean(dist_prev.kl(dist))
kl = dist_prev.kl(dist)
if sess:
kl = sess.run(kl)
kl = np.mean(kl)
# Entropy-loss component.
entropy = np.mean(dist.entropy())
entropy = dist.entropy()
if sess:
entropy = sess.run(entropy)
entropy = np.mean(entropy)
# Policy loss component.
pg_loss = np.minimum(
@ -235,9 +258,15 @@ class TestPPO(unittest.TestCase):
vf_loss = np.maximum(vf_loss1, vf_loss2)
# Overall loss.
overall_loss = np.mean(-pg_loss + policy.kl_coeff * kl +
if sess:
policy_sess = policy.get_session()
kl_coeff, entropy_coeff = policy_sess.run(
[policy.kl_coeff, policy.entropy_coeff])
else:
kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff
overall_loss = np.mean(-pg_loss + kl_coeff * kl +
policy.config["vf_loss_coeff"] * vf_loss -
policy.entropy_coeff * entropy)
entropy_coeff * entropy)
return kl, entropy, pg_loss, vf_loss, overall_loss


@ -3,6 +3,7 @@ import unittest
import ray
import ray.rllib.agents.sac as sac
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import framework_iterator
tf = try_import_tf()
@ -16,12 +17,7 @@ class TestSAC(unittest.TestCase):
num_iterations = 1
# eager (discrete and cont. actions).
for fw in ["eager", "tf", "torch"]:
print("framework={}".format(fw))
if fw == "torch":
continue
config["eager"] = fw == "eager"
config["use_pytorch"] = fw == "torch"
for _ in framework_iterator(config, ["tf", "eager"]):
for env in [
"CartPole-v0",
"Pendulum-v0",


@ -1,7 +1,6 @@
import numpy as np
from gym.spaces import Box
from scipy.stats import norm
from tensorflow.python.eager.context import eager_mode
import unittest
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
@ -9,7 +8,7 @@ from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
from ray.rllib.models.torch.torch_action_dist import TorchMultiCategorical
from ray.rllib.utils import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, softmax
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
torch, _ = try_import_torch()
@ -54,9 +53,7 @@ class TestDistributions(unittest.TestCase):
input_lengths = [num_categories] * num_sub_distributions
inputs_split = np.split(inputs, num_sub_distributions, axis=1)
for fw in ["tf", "eager", "torch"]:
print("framework={}".format(fw))
for fw in framework_iterator():
# Create the correct distribution object.
cls = MultiCategorical if fw != "torch" else TorchMultiCategorical
multi_categorical = cls(inputs, None, input_lengths)
@ -101,7 +98,8 @@ class TestDistributions(unittest.TestCase):
def test_squashed_gaussian(self):
"""Tests the SquashedGaussia ActionDistribution (tf-eager only)."""
with eager_mode():
for fw, sess in framework_iterator(
frameworks=["tf", "eager"], session=True):
input_space = Box(-1.0, 1.0, shape=(200, 10))
low, high = -2.0, 1.0
@ -122,13 +120,17 @@ class TestDistributions(unittest.TestCase):
inputs, {}, low=low, high=high)
expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
values = squashed_distribution.sample()
if sess:
values = sess.run(values)
self.assertTrue(np.max(values) < high)
self.assertTrue(np.min(values) > low)
check(np.mean(values), expected.mean(), decimals=1)
# Test log-likelihood outputs.
sampled_action_logp = squashed_distribution.sampled_action_logp()
sampled_action_logp = squashed_distribution.logp(values)
if sess:
sampled_action_logp = sess.run(sampled_action_logp)
# Convert to parameters for distr.
stds = np.exp(
np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
@ -166,12 +168,15 @@ class TestDistributions(unittest.TestCase):
np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
axis=-1)
out = squashed_distribution.logp(values)
check(out, log_prob)
outs = squashed_distribution.logp(values)
if sess:
outs = sess.run(outs)
check(outs, log_prob)
def test_gumbel_softmax(self):
"""Tests the GumbelSoftmax ActionDistribution (tf-eager only)."""
with eager_mode():
for fw, sess in framework_iterator(
frameworks=["tf", "eager"], session=True):
batch_size = 1000
num_categories = 5
input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
@ -191,6 +196,8 @@ class TestDistributions(unittest.TestCase):
gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
outs = gumbel_softmax.sample()
if sess:
outs = sess.run(outs)
check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)


@ -1,6 +1,5 @@
import numpy as np
from scipy.stats import norm
from tensorflow.python.eager.context import eager_mode
import unittest
import ray.rllib.agents.dqn as dqn
@ -8,7 +7,7 @@ import ray.rllib.agents.pg as pg
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.sac as sac
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.rllib.utils.numpy import one_hot, fc, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT
@ -37,20 +36,9 @@ def do_test_log_likelihood(run,
prev_r = None if prev_a is None else np.array(0.0)
# Test against all frameworks.
for fw in ["tf", "eager", "torch"]:
for fw in framework_iterator(config):
if run in [dqn.DQNTrainer, sac.SACTrainer] and fw == "torch":
continue
print("Testing {} with framework={}".format(run, fw))
config["eager"] = fw == "eager"
config["use_pytorch"] = fw == "torch"
eager_ctx = None
if fw == "tf":
assert not tf.executing_eagerly()
elif fw == "eager":
eager_ctx = eager_mode()
eager_ctx.__enter__()
assert tf.executing_eagerly()
trainer = run(config=config, env=env)
@ -114,9 +102,6 @@ def do_test_log_likelihood(run,
prev_reward_batch=np.array([prev_r]))
check(np.exp(logp), expected_prob, atol=0.2)
if eager_ctx:
eager_ctx.__exit__(None, None, None)
class TestComputeLogLikelihood(unittest.TestCase):
def test_dqn(self):


@ -3,7 +3,6 @@
import h5py
import numpy as np
from pathlib import Path
from tensorflow.python.eager.context import eager_mode
import unittest
import ray
@ -13,7 +12,7 @@ from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
torch, nn = try_import_torch()
@ -131,22 +130,10 @@ def model_import_test(algo, config, env):
agent_cls = get_agent_class(algo)
for fw in ["tf", "torch"]:
print("framework={}".format(fw))
config["use_pytorch"] = fw == "torch"
config["eager"] = fw == "eager"
for fw in framework_iterator(config, ["tf", "torch"]):
config["model"]["custom_model"] = "keras_model" if fw != "torch" else \
"torch_model"
eager_mode_ctx = None
if fw == "eager":
eager_mode_ctx = eager_mode()
eager_mode_ctx.__enter__()
assert tf.executing_eagerly()
elif fw == "tf":
assert not tf.executing_eagerly()
agent = agent_cls(config, env)
def current_weight(agent):
@ -184,9 +171,6 @@ def model_import_test(algo, config, env):
agent.import_model(import_file=import_file)
check(current_weight(agent), weight_after_import)
if eager_mode_ctx:
eager_mode_ctx.__exit__(None, None, None)
class TestModelImport(unittest.TestCase):
def setUp(self):


@ -13,7 +13,7 @@ from ray.rllib.utils.policy_client import PolicyClient
from ray.rllib.utils.policy_server import PolicyServer
from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
PolynomialSchedule, ExponentialSchedule, ConstantSchedule
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.utils import merge_dicts, deep_update
@ -64,6 +64,7 @@ __all__ = [
"fc",
"force_list",
"force_tuple",
"framework_iterator",
"lstm",
"one_hot",
"relu",


@ -1,6 +1,5 @@
import numpy as np
import sys
from tensorflow.python.eager.context import eager_mode
import unittest
import ray
@ -12,7 +11,7 @@ import ray.rllib.agents.impala as impala
import ray.rllib.agents.pg as pg
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.sac as sac
from ray.rllib.utils import check, try_import_tf
from ray.rllib.utils import check, framework_iterator, try_import_tf
tf = try_import_tf()
@ -30,7 +29,7 @@ def do_test_explorations(run,
config["num_workers"] = 0
# Test all frameworks.
for fw in ["tf", "eager", "torch"]:
for fw in framework_iterator(config):
if fw == "torch" and \
run in [ddpg.DDPGTrainer, dqn.DQNTrainer, dqn.SimpleQTrainer,
impala.ImpalaTrainer, sac.SACTrainer, td3.TD3Trainer]:
@ -40,9 +39,7 @@ def do_test_explorations(run,
]:
continue
print("Testing {} in framework={}".format(run, fw))
config["eager"] = fw == "eager"
config["use_pytorch"] = fw == "torch"
print("Agent={}".format(run))
# Test for both the default Agent's exploration AND the `Random`
# exploration class.
@ -54,14 +51,6 @@ def do_test_explorations(run,
config["exploration_config"] = {"type": "Random"}
print("exploration={}".format(exploration or "default"))
eager_ctx = None
if fw == "eager":
eager_ctx = eager_mode()
eager_ctx.__enter__()
assert tf.executing_eagerly()
elif fw == "tf":
assert not tf.executing_eagerly()
trainer = run(config=config, env=env)
# Make sure all actions drawn are the same, given same
@ -94,9 +83,6 @@ def do_test_explorations(run,
# Check that the stddev is not 0.0 (values differ).
check(np.std(actions), 0.0, false=True)
if eager_ctx:
eager_ctx.__exit__(None, None, None)
class TestExplorations(unittest.TestCase):
"""


@ -135,7 +135,7 @@ def fc(x, weights, biases=None):
isinstance(weights, torch.Tensor) else weights
biases = biases.detach().numpy() if \
isinstance(biases, torch.Tensor) else biases
if tf:
if tf and tf.executing_eagerly():
x = x.numpy() if isinstance(x, tf.Variable) else x
weights = weights.numpy() if isinstance(weights, tf.Variable) else \
weights
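Note: the extra executing_eagerly() check reflects that tf Variables only expose .numpy() under eager execution; in graph mode their values have to be fetched through a session instead (as the PPO test above now does with policy.get_session().run(vars)). A rough, illustrative sketch:

import tensorflow as tf

v = tf.Variable([1.0, 2.0])
if tf.executing_eagerly():
    weights = v.numpy()  # eager variables carry their values directly
else:
    # Graph-mode variables have no .numpy(); evaluate them in a session.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        weights = sess.run(v)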


@ -1,9 +1,8 @@
from tensorflow.python.eager.context import eager_mode
import unittest
from ray.rllib.utils.schedules import ConstantSchedule, \
LinearSchedule, ExponentialSchedule, PiecewiseSchedule
from ray.rllib.utils import check, try_import_tf
from ray.rllib.utils import check, framework_iterator, try_import_tf
from ray.rllib.utils.from_config import from_config
tf = try_import_tf()
@ -20,15 +19,10 @@ class TestSchedules(unittest.TestCase):
config = {"value": value}
for fw in ["tf", "torch", None]:
constant = from_config(ConstantSchedule, config, framework=fw)
for t in ts:
out = constant(t)
check(out, value)
# Test eager as well.
with eager_mode():
constant = from_config(ConstantSchedule, config, framework="tf")
for fw in framework_iterator(
frameworks=["tf", "eager", "torch", None]):
fw_ = fw if fw != "eager" else "tf"
constant = from_config(ConstantSchedule, config, framework=fw_)
for t in ts:
out = constant(t)
check(out, value)
@ -36,15 +30,11 @@ class TestSchedules(unittest.TestCase):
def test_linear_schedule(self):
ts = [0, 50, 10, 100, 90, 2, 1, 99, 23]
config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6}
for fw in ["tf", "torch", None]:
linear = from_config(LinearSchedule, config, framework=fw)
for t in ts:
out = linear(t)
check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
# Test eager as well.
with eager_mode():
linear = from_config(LinearSchedule, config, framework="tf")
for fw in framework_iterator(
frameworks=["tf", "eager", "torch", None]):
fw_ = fw if fw != "eager" else "tf"
linear = from_config(LinearSchedule, config, framework=fw_)
for t in ts:
out = linear(t)
check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
@ -58,17 +48,11 @@ class TestSchedules(unittest.TestCase):
initial_p=2.0,
final_p=0.5,
power=2.0)
for fw in ["tf", "torch", None]:
config["framework"] = fw
polynomial = from_config(config)
for t in ts:
out = polynomial(t)
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
# Test eager as well.
with eager_mode():
config["framework"] = "tf"
polynomial = from_config(config)
for fw in framework_iterator(
frameworks=["tf", "eager", "torch", None]):
fw_ = fw if fw != "eager" else "tf"
polynomial = from_config(config, framework=fw_)
for t in ts:
out = polynomial(t)
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
@ -76,17 +60,12 @@ class TestSchedules(unittest.TestCase):
def test_exponential_schedule(self):
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
config = dict(initial_p=2.0, decay_rate=0.99, schedule_timesteps=100)
for fw in ["tf", "torch", None]:
config["framework"] = fw
exponential = from_config(ExponentialSchedule, config)
for t in ts:
out = exponential(t)
check(out, 2.0 * 0.99**(t / 100), decimals=4)
# Test eager as well.
with eager_mode():
config["framework"] = "tf"
exponential = from_config(ExponentialSchedule, config)
for fw in framework_iterator(
frameworks=["tf", "eager", "torch", None]):
fw_ = fw if fw != "eager" else "tf"
exponential = from_config(
ExponentialSchedule, config, framework=fw_)
for t in ts:
out = exponential(t)
check(out, 2.0 * 0.99**(t / 100), decimals=4)
@ -97,17 +76,11 @@ class TestSchedules(unittest.TestCase):
config = dict(
endpoints=[(0, 50.0), (25, 100.0), (30, 200.0)],
outside_value=14.5)
for fw in ["tf", "torch", None]:
config["framework"] = fw
piecewise = from_config(PiecewiseSchedule, config)
for t, e in zip(ts, expected):
out = piecewise(t)
check(out, e, decimals=4)
# Test eager as well.
with eager_mode():
config["framework"] = "tf"
piecewise = from_config(PiecewiseSchedule, config)
for fw in framework_iterator(
frameworks=["tf", "eager", "torch", None]):
fw_ = fw if fw != "eager" else "tf"
piecewise = from_config(PiecewiseSchedule, config, framework=fw_)
for t, e in zip(ts, expected):
out = piecewise(t)
check(out, e, decimals=4)


@ -1,10 +1,89 @@
import logging
import numpy as np
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf = try_import_tf()
if tf:
eager_mode = None
try:
from tensorflow.python.eager.context import eager_mode
except (ImportError, ModuleNotFoundError):
pass
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
def framework_iterator(config=None,
frameworks=("tf", "eager", "torch"),
session=False):
"""An generator that allows for looping through n frameworks for testing.
Provides the correct config entries ("use_pytorch" and "eager") as well
as the correct eager/non-eager contexts for tf.
Args:
config (Optional[dict]): An optional config dict to alter in place
depending on the iteration.
frameworks (Tuple[str]): A list/tuple of the frameworks to be tested.
Allowed are: "tf", "eager", and "torch".
session (bool): If True, enter a tf.Session() and yield that as
well in the tf-case (otherwise, yield (fw, None)).
Yields:
str: If session is False:
The current framework ("tf", "eager", "torch") used.
Tuple[str, Union[None, tf.Session]]: If session is True:
A tuple of the current fw and the tf.Session if fw="tf".
"""
config = config or {}
frameworks = [frameworks] if isinstance(frameworks, str) else frameworks
for fw in frameworks:
# Skip non-installed frameworks.
if fw == "torch" and not torch:
logger.warning(
"framework_iterator skipping torch (not installed)!")
continue
elif not tf:
logger.warning("framework_iterator skipping {} (tf not "
"installed)!".format(fw))
continue
elif fw == "eager" and not eager_mode:
logger.warning("framework_iterator skipping eager (could not "
"import `eager_mode` from tensorflow.python)!")
continue
assert fw in ["tf", "eager", "torch", None]
# Do we need a test session?
sess = None
if fw == "tf" and session is True:
sess = tf.Session()
sess.__enter__()
print("framework={}".format(fw))
config["eager"] = fw == "eager"
config["use_pytorch"] = fw == "torch"
eager_ctx = None
if fw == "eager":
eager_ctx = eager_mode()
eager_ctx.__enter__()
assert tf.executing_eagerly()
elif fw == "tf":
assert not tf.executing_eagerly()
yield fw if session is False else (fw, sess)
# Exit any context we may have entered.
if eager_ctx:
eager_ctx.__exit__(None, None, None)
elif sess:
sess.__exit__(None, None, None)
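For reference, a minimal sketch of how a test consumes this generator (the PPO trainer and CartPole env mirror the tests above; session=True additionally yields a tf.Session in the "tf" case):

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.utils.test_utils import framework_iterator

ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # Run locally.

# The iterator sets config["eager"] / config["use_pytorch"] per framework and
# enters/exits the eager context (plus a tf.Session when session=True).
for fw, sess in framework_iterator(
        config, frameworks=("tf", "eager", "torch"), session=True):
    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
    print(fw, trainer.train())
    # sess is a tf.Session for fw == "tf"; None otherwise.
ray.shutdown()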
def check(x, y, decimals=5, atol=None, rtol=None, false=False):
"""


@ -7,11 +7,9 @@ import unittest
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
tf.enable_eager_execution()
torch, _ = try_import_torch()
@ -64,79 +62,108 @@ class TestFrameWorkAgnosticComponents(unittest.TestCase):
"""
def test_dummy_components(self):
# Switch on eager for testing purposes.
tf.enable_eager_execution()
# Bazel makes it hard to find files specified in `args` (and `data`).
# Bazel makes it hard to find files specified in `args`
# (and `data`).
# Use the true absolute path.
script_dir = Path(__file__).parent
abs_path = script_dir.absolute()
for fw, sess in framework_iterator(session=True):
fw_ = fw if fw != "eager" else "tf"
# Try to create from an abstract class w/o default constructor.
# Expect None.
test = from_config({
"type": AbstractDummyComponent,
"framework": "torch"
"framework": fw_
})
check(test, None)
# Create a Component via python API (config dict).
component = from_config(
dict(type=DummyComponent, prop_a=1.0, prop_d="non_default"))
dict(
type=DummyComponent,
prop_a=1.0,
prop_d="non_default",
framework=fw_))
check(component.prop_d, "non_default")
# Create a tf Component from json file.
config_file = str(abs_path.joinpath("dummy_config.json"))
component = from_config(config_file)
component = from_config(config_file, framework=fw_)
check(component.prop_c, "default")
check(component.prop_d, 4) # default
check(component.add(3.3).numpy(), 5.3) # prop_b == 2.0
value = component.add(3.3)
if sess:
value = sess.run(value)
check(value, 5.3) # prop_b == 2.0
# Create a torch Component from yaml file.
config_file = str(abs_path.joinpath("dummy_config.yml"))
component = from_config(config_file)
component = from_config(config_file, framework=fw_)
check(component.prop_a, "something else")
check(component.prop_d, 3)
check(component.add(1.2), np.array([2.2])) # prop_b == 1.0
value = component.add(1.2)
if sess:
value = sess.run(value)
check(value, np.array([2.2])) # prop_b == 1.0
# Create tf Component from json-string (e.g. on command line).
component = from_config(
'{"type": "ray.rllib.utils.tests.'
'test_framework_agnostic_components.DummyComponent", '
'"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
'"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default", '
'"framework": "' + fw_ + '"}')
check(component.prop_a, "A")
check(component.prop_d, 4) # default
check(component.add(-1.1).numpy(), -2.1) # prop_b == -1.0
value = component.add(-1.1)
if sess:
value = sess.run(value)
check(value, -2.1) # prop_b == -1.0
# Test recognizing default module path.
component = from_config(
DummyComponent, '{"type": "NonAbstractChildOfDummyComponent", '
'"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
'"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default",'
'"framework": "' + fw_ + '"}')
check(component.prop_a, "A")
check(component.prop_d, 4) # default
check(component.add(-1.1).numpy(), -2.1) # prop_b == -1.0
value = component.add(-1.1)
if sess:
value = sess.run(value)
check(value, -2.1) # prop_b == -1.0
# Test recognizing default package path.
scope = None
if sess:
scope = tf.variable_scope("exploration_object")
scope.__enter__()
component = from_config(
Exploration, {
"type": "EpsilonGreedy",
"action_space": Discrete(2),
"framework": "tf",
"framework": fw_,
"num_workers": 0,
"worker_index": 0,
"policy_config": {},
"model": None
})
if scope:
scope.__exit__(None, None, None)
check(component.epsilon_schedule.outside_value, 0.05) # default
# Create torch Component from yaml-string.
component = from_config(
"type: ray.rllib.utils.tests."
"test_framework_agnostic_components.DummyComponent\n"
"prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: torch")
"prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: "
"{}".format(fw_))
check(component.prop_a, "B")
check(component.prop_d, 4) # default
check(component.add(-5.1), np.array([-6.6])) # prop_b == -1.5
value = component.add(-5.1)
if sess:
value = sess.run(value)
check(value, np.array([-6.6])) # prop_b == -1.5
if __name__ == "__main__":