mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Some more bandit
cleanup/tests. (#21932)
This commit is contained in:
parent
0ff8bfacec
commit
7fc1683bab
1 changed file with 20 additions and 2 deletions
|
@ -22,8 +22,6 @@ class TestBandits(unittest.TestCase):
|
|||
config = {
|
||||
# Use a simple bandit friendly env.
|
||||
"env": SimpleContextualBandit,
|
||||
# Run locally.
|
||||
"num_workers": 0,
|
||||
}
|
||||
|
||||
num_iterations = 5
|
||||
|
@ -38,6 +36,26 @@ class TestBandits(unittest.TestCase):
|
|||
# Force good learning behavior (this is a very simple env).
|
||||
self.assertTrue(results["episode_reward_mean"] == 10.0)
|
||||
|
||||
def test_bandit_lin_ucb_compilation(self):
    """Test that a BanditLinUCBTrainer can be built and trained.

    Only the "torch" framework is iterated over. For each framework
    setting, a trainer is built on SimpleContextualBandit, trained for
    `num_iterations` iterations (each result is passed to
    `check_train_results` and printed), and the last reported
    `episode_reward_mean` must equal 10.0.
    """
    config = {
        # Use a simple bandit friendly env.
        "env": SimpleContextualBandit,
    }

    num_iterations = 5

    for _ in framework_iterator(config, frameworks="torch"):
        trainer = bandit.BanditLinUCBTrainer(config=config)
        try:
            results = None
            for _step in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            # Force good learning behavior (this is a very simple env).
            # assertEqual (rather than assertTrue(a == b)) reports the
            # actual reward in the failure message.
            self.assertEqual(results["episode_reward_mean"], 10.0)
        finally:
            # Always release the trainer's resources, even when a check
            # above fails, so later iterations/tests start clean.
            trainer.stop()
|
||||
|
||||
def test_deprecated_locations(self):
|
||||
"""Tests, whether importing from old contrib dir fails gracefully.
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue