[RLlib] Some more bandit cleanup/tests. (#21932)

This commit is contained in:
Sven Mika 2022-01-28 12:03:26 +01:00 committed by GitHub
parent 0ff8bfacec
commit 7fc1683bab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -22,8 +22,6 @@ class TestBandits(unittest.TestCase):
config = { config = {
# Use a simple bandit friendly env. # Use a simple bandit friendly env.
"env": SimpleContextualBandit, "env": SimpleContextualBandit,
# Run locally.
"num_workers": 0,
} }
num_iterations = 5 num_iterations = 5
@ -38,6 +36,26 @@ class TestBandits(unittest.TestCase):
# Force good learning behavior (this is a very simple env). # Force good learning behavior (this is a very simple env).
self.assertTrue(results["episode_reward_mean"] == 10.0) self.assertTrue(results["episode_reward_mean"] == 10.0)
def test_bandit_lin_ucb_compilation(self):
    """Test that a BanditLinUCBTrainer can be built and trained.

    Only the "torch" framework is iterated over here. For each framework
    the trainer is trained for a fixed number of iterations, every result
    dict is passed through `check_train_results`, and the final
    `episode_reward_mean` must equal 10.0.
    """
    config = {
        # Use a simple bandit friendly env.
        "env": SimpleContextualBandit,
    }

    num_iterations = 5

    for _ in framework_iterator(config, frameworks="torch"):
        trainer = bandit.BanditLinUCBTrainer(config=config)
        try:
            results = None
            for _step in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            # Force good learning behavior (this is a very simple env).
            # assertEqual (rather than assertTrue(a == b)) reports the
            # actual reward in the failure message.
            self.assertEqual(results["episode_reward_mean"], 10.0)
        finally:
            # Always shut the trainer down, even if training or the
            # assertion above fails, so it is not left running for the
            # next framework iteration / test.
            trainer.stop()
def test_deprecated_locations(self): def test_deprecated_locations(self):
"""Tests, whether importing from old contrib dir fails gracefully. """Tests, whether importing from old contrib dir fails gracefully.