Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Add testing framework_iterator. (#7852)
* Add testing framework_iterator.
* LINT.
* WIP.
* Fix and LINT.
* LINT fix.
parent bb6c675231
commit 1d4823c0ec
15 changed files with 323 additions and 303 deletions
@@ -3,7 +3,7 @@ import unittest
 
 import ray.rllib.agents.ddpg as ddpg
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
 
@@ -15,11 +15,7 @@ class TestDDPG(unittest.TestCase):
         config["num_workers"] = 0  # Run locally.
 
         # Test against all frameworks.
-        for fw in ["tf", "eager", "torch"]:
-            if fw != "tf":
-                continue
-            config["eager"] = True if fw == "eager" else False
-            config["use_pytorch"] = True if fw == "torch" else False
+        for _ in framework_iterator(config, "tf"):
             trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
             num_iterations = 2
             for i in range(num_iterations):
@@ -33,12 +29,7 @@ class TestDDPG(unittest.TestCase):
         obs = np.array([0.0, 0.1, -0.1])
 
         # Test against all frameworks.
-        for fw in ["tf", "eager", "torch"]:
-            if fw != "tf":
-                continue
-            config["eager"] = True if fw == "eager" else False
-            config["use_pytorch"] = True if fw == "torch" else False
-
+        for _ in framework_iterator(config, "tf"):
             # Default OUNoise setup.
             trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
             # Setting explore=False should always return the same action.

@@ -3,7 +3,7 @@ import unittest
 
 import ray.rllib.agents.ddpg.td3 as td3
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
 
@@ -15,11 +15,7 @@ class TestTD3(unittest.TestCase):
         config["num_workers"] = 0  # Run locally.
 
         # Test against all frameworks.
-        for fw in ["tf", "eager", "torch"]:
-            if fw != "tf":
-                continue
-            config["eager"] = True if fw == "eager" else False
-            config["use_pytorch"] = True if fw == "torch" else False
+        for _ in framework_iterator(config, frameworks=["tf"]):
             trainer = td3.TD3Trainer(config=config, env="Pendulum-v0")
             num_iterations = 2
             for i in range(num_iterations):
@@ -33,12 +29,7 @@ class TestTD3(unittest.TestCase):
         obs = np.array([0.0, 0.1, -0.1])
 
         # Test against all frameworks.
-        for fw in ["tf", "eager", "torch"]:
-            if fw != "tf":
-                continue
-            config["eager"] = True if fw == "eager" else False
-            config["use_pytorch"] = True if fw == "torch" else False
-
+        for _ in framework_iterator(config, frameworks="tf"):
             # Default GaussianNoise setup.
             trainer = td3.TD3Trainer(config=config, env="Pendulum-v0")
             # Setting explore=False should always return the same action.

@@ -77,7 +77,8 @@ class QLoss:
             # priority is robust and insensitive to `prioritized_replay_alpha`
             self.td_error = tf.nn.softmax_cross_entropy_with_logits(
                 labels=m, logits=q_logits_t_selected)
-            self.loss = tf.reduce_mean(self.td_error * importance_weights)
+            self.loss = tf.reduce_mean(
+                self.td_error * tf.cast(importance_weights, tf.float32))
             self.stats = {
                 # TODO: better Q stats for dist dqn
                 "mean_td_error": tf.reduce_mean(self.td_error),

@@ -1,10 +1,9 @@
 import numpy as np
-from tensorflow.python.eager.context import eager_mode
 import unittest
 
 import ray.rllib.agents.dqn as dqn
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
 
@@ -14,42 +13,28 @@ class TestDQN(unittest.TestCase):
         """Test whether a DQNTrainer can be built with both frameworks."""
         config = dqn.DEFAULT_CONFIG.copy()
         config["num_workers"] = 0  # Run locally.
+        num_iterations = 2
 
+        for _ in framework_iterator(config, frameworks=["tf", "eager"]):
             # Rainbow.
             rainbow_config = config.copy()
-        rainbow_config["eager"] = False
             rainbow_config["num_atoms"] = 10
             rainbow_config["noisy"] = True
             rainbow_config["double_q"] = True
             rainbow_config["dueling"] = True
             rainbow_config["n_step"] = 5
             trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0")
-        num_iterations = 2
             for i in range(num_iterations):
                 results = trainer.train()
                 print(results)
 
-        # tf.
-        tf_config = config.copy()
-        tf_config["eager"] = False
-        trainer = dqn.DQNTrainer(config=tf_config, env="CartPole-v0")
-        num_iterations = 1
+            # double-dueling DQN.
+            plain_config = config.copy()
+            trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
                 print(results)
 
-        # Eager.
-        eager_config = config.copy()
-        eager_config["eager"] = True
-        eager_ctx = eager_mode()
-        eager_ctx.__enter__()
-        trainer = dqn.DQNTrainer(config=eager_config, env="CartPole-v0")
-        num_iterations = 1
-        for i in range(num_iterations):
-            results = trainer.train()
-            print(results)
-        eager_ctx.__exit__(None, None, None)
-
     def test_dqn_exploration_and_soft_q_config(self):
         """Tests, whether a DQN Agent outputs exploration/softmaxed actions."""
         config = dqn.DEFAULT_CONFIG.copy()
@@ -58,22 +43,7 @@ class TestDQN(unittest.TestCase):
         obs = np.array(0)
 
         # Test against all frameworks.
-        for fw in ["tf", "eager", "torch"]:
-            if fw == "torch":
-                continue
-
-            print("framework={}".format(fw))
-
-            eager_mode_ctx = None
-            if fw == "tf":
-                assert not tf.executing_eagerly()
-            else:
-                eager_mode_ctx = eager_mode()
-                eager_mode_ctx.__enter__()
-
-            config["eager"] = fw == "eager"
-            config["use_pytorch"] = fw == "torch"
-
+        for _ in framework_iterator(config, ["tf", "eager"]):
            # Default EpsilonGreedy setup.
            trainer = dqn.DQNTrainer(config=config, env="FrozenLake-v0")
            # Setting explore=False should always return the same action.

@@ -14,7 +14,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.numpy import fc
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
 
@@ -32,16 +32,9 @@ class TestPPO(unittest.TestCase):
         """Test whether a PPOTrainer can be built with both frameworks."""
         config = ppo.DEFAULT_CONFIG.copy()
         config["num_workers"] = 0  # Run locally.
 
-        # tf.
-        trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
-
         num_iterations = 2
-        for i in range(num_iterations):
-            trainer.train()
-
-        # Torch.
-        config["use_pytorch"] = True
+        for _ in framework_iterator(config):
             trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
             for i in range(num_iterations):
                 trainer.train()
@@ -81,10 +74,7 @@ class TestPPO(unittest.TestCase):
         obs = np.array(0)
 
         # Test against all frameworks.
-        for fw in ["tf", "eager", "torch"]:
-            config["eager"] = True if fw == "eager" else False
-            config["use_pytorch"] = True if fw == "torch" else False
-
+        for fw in framework_iterator(config):
             # Default Agent should be setup with StochasticSampling.
             trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v0")
             # explore=False, always expect the same (deterministic) action.
@@ -131,7 +121,10 @@ class TestPPO(unittest.TestCase):
                 [0.9, 1.0, 1.1, 1.2]],
                 dtype=np.float32),
             SampleBatch.ACTIONS: np.array([0, 1, 1]),
+            SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
             SampleBatch.REWARDS: np.array([1.0, -1.0, .5], dtype=np.float32),
+            SampleBatch.PREV_REWARDS: np.array(
+                [1.0, -1.0, .5], dtype=np.float32),
             SampleBatch.DONES: np.array([False, False, True]),
             SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
             SampleBatch.ACTION_DIST_INPUTS: np.array(
@@ -140,11 +133,8 @@ class TestPPO(unittest.TestCase):
                 [-0.5, -0.1, -0.2], dtype=np.float32),
         }
 
-        for fw in ["tf", "torch"]:
-            print("framework={}".format(fw))
-            config["use_pytorch"] = fw == "torch"
-            config["eager"] = fw == "tf"
-
+        for fw, sess in framework_iterator(
+                config, frameworks=["eager", "tf", "torch"], session=True):
             trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
             policy = trainer.get_policy()
 
@@ -152,7 +142,7 @@ class TestPPO(unittest.TestCase):
             # to train_batch dict.
             # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
             # [0.50005, -0.505, 0.5]
-            if fw == "tf":
+            if fw == "tf" or fw == "eager":
                 train_batch = postprocess_ppo_gae_tf(policy, train_batch)
             else:
                 train_batch = postprocess_ppo_gae_torch(policy, train_batch)
@@ -162,17 +152,18 @@ class TestPPO(unittest.TestCase):
             check(train_batch[Postprocessing.VALUE_TARGETS],
                   [0.50005, -0.505, 0.5])
 
-            # Calculate actual PPO loss (results are stored in policy.loss_obj)
-            # for tf.
-            if fw == "tf":
+            # Calculate actual PPO loss.
+            if fw == "eager":
                 ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                       train_batch)
-            else:
+            elif fw == "torch":
                 ppo_surrogate_loss_torch(policy, policy.model,
                                          TorchCategorical, train_batch)
 
-            vars = policy.model.variables() if fw == "tf" else \
+            vars = policy.model.variables() if fw != "torch" else \
                 list(policy.model.parameters())
+            if fw == "tf":
+                vars = policy.get_session().run(vars)
             expected_shared_out = fc(train_batch[SampleBatch.CUR_OBS], vars[0],
                                      vars[1])
             expected_logits = fc(expected_shared_out, vars[2], vars[3])
@@ -181,18 +172,42 @@ class TestPPO(unittest.TestCase):
             kl, entropy, pg_loss, vf_loss, overall_loss = \
                 self._ppo_loss_helper(
                     policy, policy.model,
-                    Categorical if fw == "tf" else TorchCategorical,
+                    Categorical if fw != "torch" else TorchCategorical,
                     train_batch,
-                    expected_logits, expected_value_outs
+                    expected_logits, expected_value_outs,
+                    sess=sess
                 )
+            if sess:
+                policy_sess = policy.get_session()
+                k, e, pl, v, tl = policy_sess.run(
+                    [
+                        policy.loss_obj.mean_kl, policy.loss_obj.mean_entropy,
+                        policy.loss_obj.mean_policy_loss,
+                        policy.loss_obj.mean_vf_loss, policy.loss_obj.loss
+                    ],
+                    feed_dict=policy._get_loss_inputs_dict(
+                        train_batch, shuffle=False))
+                check(k, kl)
+                check(e, entropy)
+                check(pl, np.mean(-pg_loss))
+                check(v, np.mean(vf_loss), decimals=4)
+                check(tl, overall_loss, decimals=4)
+            else:
                 check(policy.loss_obj.mean_kl, kl)
                 check(policy.loss_obj.mean_entropy, entropy)
                 check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
-            check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
+                check(
+                    policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
                 check(policy.loss_obj.loss, overall_loss, decimals=4)
 
-    def _ppo_loss_helper(self, policy, model, dist_class, train_batch, logits,
-                         vf_outs):
+    def _ppo_loss_helper(self,
+                         policy,
+                         model,
+                         dist_class,
+                         train_batch,
+                         logits,
+                         vf_outs,
+                         sess=None):
         """
         Calculates the expected PPO loss (components) given Policy,
         Model, distribution, some batch, logits & vf outputs, using numpy.
@@ -210,12 +225,20 @@ class TestPPO(unittest.TestCase):
             # Entropy-loss component.
             entropy = np.mean(dist.entropy().detach().numpy())
         else:
+            if sess:
+                expected_logp = sess.run(expected_logp)
             expected_rho = np.exp(expected_logp -
                                   train_batch[SampleBatch.ACTION_LOGP])
             # KL(prev vs current action dist)-loss component.
-            kl = np.mean(dist_prev.kl(dist))
+            kl = dist_prev.kl(dist)
+            if sess:
+                kl = sess.run(kl)
+            kl = np.mean(kl)
             # Entropy-loss component.
-            entropy = np.mean(dist.entropy())
+            entropy = dist.entropy()
+            if sess:
+                entropy = sess.run(entropy)
+            entropy = np.mean(entropy)
 
         # Policy loss component.
         pg_loss = np.minimum(
@@ -235,9 +258,15 @@ class TestPPO(unittest.TestCase):
         vf_loss = np.maximum(vf_loss1, vf_loss2)
 
         # Overall loss.
-        overall_loss = np.mean(-pg_loss + policy.kl_coeff * kl +
+        if sess:
+            policy_sess = policy.get_session()
+            kl_coeff, entropy_coeff = policy_sess.run(
+                [policy.kl_coeff, policy.entropy_coeff])
+        else:
+            kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff
+        overall_loss = np.mean(-pg_loss + kl_coeff * kl +
                                policy.config["vf_loss_coeff"] * vf_loss -
-                               policy.entropy_coeff * entropy)
+                               entropy_coeff * entropy)
         return kl, entropy, pg_loss, vf_loss, overall_loss
 
 
@@ -3,6 +3,7 @@ import unittest
 import ray
 import ray.rllib.agents.sac as sac
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.test_utils import framework_iterator
 
 tf = try_import_tf()
 
@@ -16,12 +17,7 @@ class TestSAC(unittest.TestCase):
         num_iterations = 1
 
         # eager (discrete and cont. actions).
-        for fw in ["eager", "tf", "torch"]:
-            print("framework={}".format(fw))
-            if fw == "torch":
-                continue
-            config["eager"] = fw == "eager"
-            config["use_pytorch"] = fw == "torch"
+        for _ in framework_iterator(config, ["tf", "eager"]):
             for env in [
                     "CartPole-v0",
                     "Pendulum-v0",

@@ -1,7 +1,6 @@
 import numpy as np
 from gym.spaces import Box
 from scipy.stats import norm
-from tensorflow.python.eager.context import eager_mode
 import unittest
 
 from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
@@ -9,7 +8,7 @@ from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
 from ray.rllib.models.torch.torch_action_dist import TorchMultiCategorical
 from ray.rllib.utils import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, softmax
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
 torch, _ = try_import_torch()
@@ -54,9 +53,7 @@ class TestDistributions(unittest.TestCase):
         input_lengths = [num_categories] * num_sub_distributions
         inputs_split = np.split(inputs, num_sub_distributions, axis=1)
 
-        for fw in ["tf", "eager", "torch"]:
-            print("framework={}".format(fw))
-
+        for fw in framework_iterator():
             # Create the correct distribution object.
             cls = MultiCategorical if fw != "torch" else TorchMultiCategorical
             multi_categorical = cls(inputs, None, input_lengths)
@@ -101,7 +98,8 @@ class TestDistributions(unittest.TestCase):
 
     def test_squashed_gaussian(self):
         """Tests the SquashedGaussia ActionDistribution (tf-eager only)."""
-        with eager_mode():
+        for fw, sess in framework_iterator(
+                frameworks=["tf", "eager"], session=True):
             input_space = Box(-1.0, 1.0, shape=(200, 10))
             low, high = -2.0, 1.0
 
@@ -122,13 +120,17 @@ class TestDistributions(unittest.TestCase):
                 inputs, {}, low=low, high=high)
             expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
             values = squashed_distribution.sample()
+            if sess:
+                values = sess.run(values)
             self.assertTrue(np.max(values) < high)
             self.assertTrue(np.min(values) > low)
 
             check(np.mean(values), expected.mean(), decimals=1)
 
             # Test log-likelihood outputs.
-            sampled_action_logp = squashed_distribution.sampled_action_logp()
+            sampled_action_logp = squashed_distribution.logp(values)
+            if sess:
+                sampled_action_logp = sess.run(sampled_action_logp)
             # Convert to parameters for distr.
             stds = np.exp(
                 np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
@@ -166,12 +168,15 @@ class TestDistributions(unittest.TestCase):
                 np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
                        axis=-1)
 
-            out = squashed_distribution.logp(values)
-            check(out, log_prob)
+            outs = squashed_distribution.logp(values)
+            if sess:
+                outs = sess.run(outs)
+            check(outs, log_prob)
 
     def test_gumbel_softmax(self):
         """Tests the GumbelSoftmax ActionDistribution (tf-eager only)."""
-        with eager_mode():
+        for fw, sess in framework_iterator(
+                frameworks=["tf", "eager"], session=True):
             batch_size = 1000
             num_categories = 5
             input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
@@ -191,6 +196,8 @@ class TestDistributions(unittest.TestCase):
             gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
             expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
             outs = gumbel_softmax.sample()
+            if sess:
+                outs = sess.run(outs)
             check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)
 
 
@@ -1,6 +1,5 @@
 import numpy as np
 from scipy.stats import norm
-from tensorflow.python.eager.context import eager_mode
 import unittest
 
 import ray.rllib.agents.dqn as dqn
@@ -8,7 +7,7 @@ import ray.rllib.agents.pg as pg
 import ray.rllib.agents.ppo as ppo
 import ray.rllib.agents.sac as sac
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 from ray.rllib.utils.numpy import one_hot, fc, MIN_LOG_NN_OUTPUT, \
     MAX_LOG_NN_OUTPUT
 
@@ -37,20 +36,9 @@ def do_test_log_likelihood(run,
     prev_r = None if prev_a is None else np.array(0.0)
 
     # Test against all frameworks.
-    for fw in ["tf", "eager", "torch"]:
+    for fw in framework_iterator(config):
         if run in [dqn.DQNTrainer, sac.SACTrainer] and fw == "torch":
             continue
-        print("Testing {} with framework={}".format(run, fw))
-        config["eager"] = fw == "eager"
-        config["use_pytorch"] = fw == "torch"
-
-        eager_ctx = None
-        if fw == "tf":
-            assert not tf.executing_eagerly()
-        elif fw == "eager":
-            eager_ctx = eager_mode()
-            eager_ctx.__enter__()
-            assert tf.executing_eagerly()
 
         trainer = run(config=config, env=env)
 
@@ -114,9 +102,6 @@ def do_test_log_likelihood(run,
             prev_reward_batch=np.array([prev_r]))
         check(np.exp(logp), expected_prob, atol=0.2)
 
-    if eager_ctx:
-        eager_ctx.__exit__(None, None, None)
-
 
 class TestComputeLogLikelihood(unittest.TestCase):
     def test_dqn(self):

@@ -3,7 +3,6 @@
 import h5py
 import numpy as np
 from pathlib import Path
-from tensorflow.python.eager.context import eager_mode
 import unittest
 
 import ray
@@ -13,7 +12,7 @@ from ray.rllib.models.tf.misc import normc_initializer
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
 torch, nn = try_import_torch()
@@ -131,22 +130,10 @@ def model_import_test(algo, config, env):
 
     agent_cls = get_agent_class(algo)
 
-    for fw in ["tf", "torch"]:
-        print("framework={}".format(fw))
-
-        config["use_pytorch"] = fw == "torch"
-        config["eager"] = fw == "eager"
+    for fw in framework_iterator(config, ["tf", "torch"]):
         config["model"]["custom_model"] = "keras_model" if fw != "torch" else \
             "torch_model"
 
-        eager_mode_ctx = None
-        if fw == "eager":
-            eager_mode_ctx = eager_mode()
-            eager_mode_ctx.__enter__()
-            assert tf.executing_eagerly()
-        elif fw == "tf":
-            assert not tf.executing_eagerly()
-
         agent = agent_cls(config, env)
 
     def current_weight(agent):
@@ -184,9 +171,6 @@ def model_import_test(algo, config, env):
         agent.import_model(import_file=import_file)
         check(current_weight(agent), weight_after_import)
 
-        if eager_mode_ctx:
-            eager_mode_ctx.__exit__(None, None, None)
-
 
 class TestModelImport(unittest.TestCase):
     def setUp(self):

@@ -13,7 +13,7 @@ from ray.rllib.utils.policy_client import PolicyClient
 from ray.rllib.utils.policy_server import PolicyServer
 from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
     PolynomialSchedule, ExponentialSchedule, ConstantSchedule
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 from ray.tune.utils import merge_dicts, deep_update
 
 
@@ -64,6 +64,7 @@ __all__ = [
     "fc",
     "force_list",
     "force_tuple",
+    "framework_iterator",
     "lstm",
     "one_hot",
     "relu",

@@ -1,6 +1,5 @@
 import numpy as np
 import sys
-from tensorflow.python.eager.context import eager_mode
 import unittest
 
 import ray
@@ -12,7 +11,7 @@ import ray.rllib.agents.impala as impala
 import ray.rllib.agents.pg as pg
 import ray.rllib.agents.ppo as ppo
 import ray.rllib.agents.sac as sac
-from ray.rllib.utils import check, try_import_tf
+from ray.rllib.utils import check, framework_iterator, try_import_tf
 
 tf = try_import_tf()
 
@@ -30,7 +29,7 @@ def do_test_explorations(run,
     config["num_workers"] = 0
 
     # Test all frameworks.
-    for fw in ["tf", "eager", "torch"]:
+    for fw in framework_iterator(config):
         if fw == "torch" and \
                 run in [ddpg.DDPGTrainer, dqn.DQNTrainer, dqn.SimpleQTrainer,
                         impala.ImpalaTrainer, sac.SACTrainer, td3.TD3Trainer]:
@@ -40,9 +39,7 @@ def do_test_explorations(run,
         ]:
             continue
 
-        print("Testing {} in framework={}".format(run, fw))
-        config["eager"] = fw == "eager"
-        config["use_pytorch"] = fw == "torch"
+        print("Agent={}".format(run))
 
         # Test for both the default Agent's exploration AND the `Random`
         # exploration class.
@@ -54,14 +51,6 @@ def do_test_explorations(run,
             config["exploration_config"] = {"type": "Random"}
         print("exploration={}".format(exploration or "default"))
 
-        eager_ctx = None
-        if fw == "eager":
-            eager_ctx = eager_mode()
-            eager_ctx.__enter__()
-            assert tf.executing_eagerly()
-        elif fw == "tf":
-            assert not tf.executing_eagerly()
-
         trainer = run(config=config, env=env)
 
         # Make sure all actions drawn are the same, given same
@@ -94,9 +83,6 @@ def do_test_explorations(run,
         # Check that the stddev is not 0.0 (values differ).
        check(np.std(actions), 0.0, false=True)
 
-        if eager_ctx:
-            eager_ctx.__exit__(None, None, None)
-
 
 class TestExplorations(unittest.TestCase):
     """

@@ -135,7 +135,7 @@ def fc(x, weights, biases=None):
         isinstance(weights, torch.Tensor) else weights
     biases = biases.detach().numpy() if \
         isinstance(biases, torch.Tensor) else biases
-    if tf:
+    if tf and tf.executing_eagerly():
         x = x.numpy() if isinstance(x, tf.Variable) else x
         weights = weights.numpy() if isinstance(weights, tf.Variable) else \
             weights

@@ -1,9 +1,8 @@
-from tensorflow.python.eager.context import eager_mode
 import unittest
 
 from ray.rllib.utils.schedules import ConstantSchedule, \
     LinearSchedule, ExponentialSchedule, PiecewiseSchedule
-from ray.rllib.utils import check, try_import_tf
+from ray.rllib.utils import check, framework_iterator, try_import_tf
 from ray.rllib.utils.from_config import from_config
 
 tf = try_import_tf()
@@ -20,15 +19,10 @@ class TestSchedules(unittest.TestCase):
 
         config = {"value": value}
 
-        for fw in ["tf", "torch", None]:
-            constant = from_config(ConstantSchedule, config, framework=fw)
-            for t in ts:
-                out = constant(t)
-                check(out, value)
-
-        # Test eager as well.
-        with eager_mode():
-            constant = from_config(ConstantSchedule, config, framework="tf")
+        for fw in framework_iterator(
+                frameworks=["tf", "eager", "torch", None]):
+            fw_ = fw if fw != "eager" else "tf"
+            constant = from_config(ConstantSchedule, config, framework=fw_)
             for t in ts:
                 out = constant(t)
                 check(out, value)
@@ -36,15 +30,11 @@ class TestSchedules(unittest.TestCase):
     def test_linear_schedule(self):
         ts = [0, 50, 10, 100, 90, 2, 1, 99, 23]
         config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6}
-        for fw in ["tf", "torch", None]:
-            linear = from_config(LinearSchedule, config, framework=fw)
-            for t in ts:
-                out = linear(t)
-                check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
 
-        # Test eager as well.
-        with eager_mode():
-            linear = from_config(LinearSchedule, config, framework="tf")
+        for fw in framework_iterator(
+                frameworks=["tf", "eager", "torch", None]):
+            fw_ = fw if fw != "eager" else "tf"
+            linear = from_config(LinearSchedule, config, framework=fw_)
             for t in ts:
                 out = linear(t)
                 check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
@@ -58,17 +48,11 @@ class TestSchedules(unittest.TestCase):
             initial_p=2.0,
             final_p=0.5,
             power=2.0)
-        for fw in ["tf", "torch", None]:
-            config["framework"] = fw
-            polynomial = from_config(config)
-            for t in ts:
-                out = polynomial(t)
-                check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
 
-        # Test eager as well.
-        with eager_mode():
-            config["framework"] = "tf"
-            polynomial = from_config(config)
+        for fw in framework_iterator(
+                frameworks=["tf", "eager", "torch", None]):
+            fw_ = fw if fw != "eager" else "tf"
+            polynomial = from_config(config, framework=fw_)
             for t in ts:
                 out = polynomial(t)
                 check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
@@ -76,17 +60,12 @@ class TestSchedules(unittest.TestCase):
     def test_exponential_schedule(self):
         ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
         config = dict(initial_p=2.0, decay_rate=0.99, schedule_timesteps=100)
-        for fw in ["tf", "torch", None]:
-            config["framework"] = fw
-            exponential = from_config(ExponentialSchedule, config)
-            for t in ts:
-                out = exponential(t)
-                check(out, 2.0 * 0.99**(t / 100), decimals=4)
 
-        # Test eager as well.
-        with eager_mode():
-            config["framework"] = "tf"
-            exponential = from_config(ExponentialSchedule, config)
+        for fw in framework_iterator(
+                frameworks=["tf", "eager", "torch", None]):
+            fw_ = fw if fw != "eager" else "tf"
+            exponential = from_config(
+                ExponentialSchedule, config, framework=fw_)
             for t in ts:
                 out = exponential(t)
                 check(out, 2.0 * 0.99**(t / 100), decimals=4)
@@ -97,17 +76,11 @@ class TestSchedules(unittest.TestCase):
         config = dict(
             endpoints=[(0, 50.0), (25, 100.0), (30, 200.0)],
             outside_value=14.5)
-        for fw in ["tf", "torch", None]:
-            config["framework"] = fw
-            piecewise = from_config(PiecewiseSchedule, config)
-            for t, e in zip(ts, expected):
-                out = piecewise(t)
-                check(out, e, decimals=4)
 
-        # Test eager as well.
-        with eager_mode():
-            config["framework"] = "tf"
-            piecewise = from_config(PiecewiseSchedule, config)
+        for fw in framework_iterator(
+                frameworks=["tf", "eager", "torch", None]):
+            fw_ = fw if fw != "eager" else "tf"
+            piecewise = from_config(PiecewiseSchedule, config, framework=fw_)
             for t, e in zip(ts, expected):
                 out = piecewise(t)
                 check(out, e, decimals=4)

@@ -1,10 +1,89 @@
+import logging
 import numpy as np
 
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 
 tf = try_import_tf()
+if tf:
+    eager_mode = None
+    try:
+        from tensorflow.python.eager.context import eager_mode
+    except (ImportError, ModuleNotFoundError):
+        pass
+
 torch, _ = try_import_torch()
+
+logger = logging.getLogger(__name__)
+
+
+def framework_iterator(config=None,
+                       frameworks=("tf", "eager", "torch"),
+                       session=False):
+    """An generator that allows for looping through n frameworks for testing.
+
+    Provides the correct config entries ("use_pytorch" and "eager") as well
+    as the correct eager/non-eager contexts for tf.
+
+    Args:
+        config (Optional[dict]): An optional config dict to alter in place
+            depending on the iteration.
+        frameworks (Tuple[str]): A list/tuple of the frameworks to be tested.
+            Allowed are: "tf", "eager", and "torch".
+        session (bool): If True, enter a tf.Session() and yield that as
+            well in the tf-case (otherwise, yield (fw, None)).
+
+    Yields:
+        str: If enter_session is False:
+            The current framework ("tf", "eager", "torch") used.
+        Tuple(str, Union[None,tf.Session]: If enter_session is True:
+            A tuple of the current fw and the tf.Session if fw="tf".
+    """
+    config = config or {}
+    frameworks = [frameworks] if isinstance(frameworks, str) else frameworks
+
+    for fw in frameworks:
+        # Skip non-installed frameworks.
+        if fw == "torch" and not torch:
+            logger.warning(
+                "framework_iterator skipping torch (not installed)!")
+            continue
+        elif not tf:
+            logger.warning("framework_iterator skipping {} (tf not "
+                           "installed)!".format(fw))
+            continue
+        elif fw == "eager" and not eager_mode:
+            logger.warning("framework_iterator skipping eager (could not "
+                           "import `eager_mode` from tensorflow.python)!")
+            continue
+        assert fw in ["tf", "eager", "torch", None]
+
+        # Do we need a test session?
+        sess = None
+        if fw == "tf" and session is True:
+            sess = tf.Session()
+            sess.__enter__()
+
+        print("framework={}".format(fw))
+
+        config["eager"] = fw == "eager"
+        config["use_pytorch"] = fw == "torch"
+
+        eager_ctx = None
+        if fw == "eager":
+            eager_ctx = eager_mode()
+            eager_ctx.__enter__()
+            assert tf.executing_eagerly()
+        elif fw == "tf":
+            assert not tf.executing_eagerly()
+
+        yield fw if session is False else (fw, sess)
+
+        # Exit any context we may have entered.
+        if eager_ctx:
+            eager_ctx.__exit__(None, None, None)
+        elif sess:
+            sess.__exit__(None, None, None)
 
 
 def check(x, y, decimals=5, atol=None, rtol=None, false=False):
     """

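For orientation, here is a minimal usage sketch of the new helper (not part of this diff); the PG agent and CartPole-v0 are stand-in choices, and any small agent/env pair should work the same way:

    import ray
    import ray.rllib.agents.pg as pg
    from ray.rllib.utils.test_utils import framework_iterator

    ray.init(num_cpus=2)
    config = pg.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # Run locally.

    # Each iteration sets config["eager"]/config["use_pytorch"] to match the
    # yielded framework and enters/exits the matching tf eager context.
    for fw in framework_iterator(config, frameworks=("tf", "eager", "torch")):
        trainer = pg.PGTrainer(config=config, env="CartPole-v0")
        print(fw, trainer.train()["episode_reward_mean"])
        trainer.stop()

    # With session=True, a "tf" iteration also yields an already-entered
    # tf.Session (None for eager/torch) for use with sess.run(...).
    for fw, sess in framework_iterator(config, frameworks=["tf"], session=True):
        trainer = pg.PGTrainer(config=config, env="CartPole-v0")
        trainer.stop()

Note that the generator mutates the passed-in config in place rather than copying it, which is why a trainer built inside the loop picks up the framework selected for that iteration.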
@@ -7,11 +7,9 @@ import unittest
 from ray.rllib.utils.exploration.exploration import Exploration
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.from_config import from_config
-from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.test_utils import check, framework_iterator
 
 tf = try_import_tf()
-tf.enable_eager_execution()
-
 torch, _ = try_import_torch()
 
 
@@ -64,79 +62,108 @@ class TestFrameWorkAgnosticComponents(unittest.TestCase):
     """
 
     def test_dummy_components(self):
-        # Switch on eager for testing purposes.
-        tf.enable_eager_execution()
 
-        # Bazel makes it hard to find files specified in `args` (and `data`).
+        # Bazel makes it hard to find files specified in `args`
+        # (and `data`).
         # Use the true absolute path.
         script_dir = Path(__file__).parent
         abs_path = script_dir.absolute()
 
+        for fw, sess in framework_iterator(session=True):
+            fw_ = fw if fw != "eager" else "tf"
             # Try to create from an abstract class w/o default constructor.
             # Expect None.
             test = from_config({
                 "type": AbstractDummyComponent,
-                "framework": "torch"
+                "framework": fw_
             })
             check(test, None)
 
             # Create a Component via python API (config dict).
             component = from_config(
-                dict(type=DummyComponent, prop_a=1.0, prop_d="non_default"))
+                dict(
+                    type=DummyComponent,
+                    prop_a=1.0,
+                    prop_d="non_default",
+                    framework=fw_))
             check(component.prop_d, "non_default")
 
             # Create a tf Component from json file.
             config_file = str(abs_path.joinpath("dummy_config.json"))
-            component = from_config(config_file)
+            component = from_config(config_file, framework=fw_)
             check(component.prop_c, "default")
             check(component.prop_d, 4)  # default
-            check(component.add(3.3).numpy(), 5.3)  # prop_b == 2.0
+            value = component.add(3.3)
+            if sess:
+                value = sess.run(value)
+            check(value, 5.3)  # prop_b == 2.0
 
             # Create a torch Component from yaml file.
             config_file = str(abs_path.joinpath("dummy_config.yml"))
-            component = from_config(config_file)
+            component = from_config(config_file, framework=fw_)
             check(component.prop_a, "something else")
             check(component.prop_d, 3)
-            check(component.add(1.2), np.array([2.2]))  # prop_b == 1.0
+            value = component.add(1.2)
+            if sess:
+                value = sess.run(value)
+            check(value, np.array([2.2]))  # prop_b == 1.0
 
             # Create tf Component from json-string (e.g. on command line).
             component = from_config(
                 '{"type": "ray.rllib.utils.tests.'
                 'test_framework_agnostic_components.DummyComponent", '
-                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
+                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default", '
+                '"framework": "' + fw_ + '"}')
             check(component.prop_a, "A")
             check(component.prop_d, 4)  # default
-            check(component.add(-1.1).numpy(), -2.1)  # prop_b == -1.0
+            value = component.add(-1.1)
+            if sess:
+                value = sess.run(value)
+            check(value, -2.1)  # prop_b == -1.0
 
             # Test recognizing default module path.
             component = from_config(
                 DummyComponent, '{"type": "NonAbstractChildOfDummyComponent", '
-                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default"}')
+                '"prop_a": "A", "prop_b": -1.0, "prop_c": "non-default",'
+                '"framework": "' + fw_ + '"}')
             check(component.prop_a, "A")
             check(component.prop_d, 4)  # default
-            check(component.add(-1.1).numpy(), -2.1)  # prop_b == -1.0
+            value = component.add(-1.1)
+            if sess:
+                value = sess.run(value)
+            check(value, -2.1)  # prop_b == -1.0
 
             # Test recognizing default package path.
+            scope = None
+            if sess:
+                scope = tf.variable_scope("exploration_object")
+                scope.__enter__()
             component = from_config(
                 Exploration, {
                     "type": "EpsilonGreedy",
                     "action_space": Discrete(2),
-                    "framework": "tf",
+                    "framework": fw_,
                     "num_workers": 0,
                     "worker_index": 0,
                     "policy_config": {},
                     "model": None
                 })
+            if scope:
+                scope.__exit__(None, None, None)
             check(component.epsilon_schedule.outside_value, 0.05)  # default
 
             # Create torch Component from yaml-string.
             component = from_config(
                 "type: ray.rllib.utils.tests."
                 "test_framework_agnostic_components.DummyComponent\n"
-                "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: torch")
+                "prop_a: B\nprop_b: -1.5\nprop_c: non-default\nframework: "
+                "{}".format(fw_))
             check(component.prop_a, "B")
             check(component.prop_d, 4)  # default
-            check(component.add(-5.1), np.array([-6.6]))  # prop_b == -1.5
+            value = component.add(-5.1)
+            if sess:
+                value = sess.run(value)
+            check(value, np.array([-6.6]))  # prop_b == -1.5
 
 
 if __name__ == "__main__":