mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Some more bandit
cleanup/tests. (#21932)
This commit is contained in:
parent
0ff8bfacec
commit
7fc1683bab
1 changed file with 20 additions and 2 deletions
|
@ -22,8 +22,6 @@ class TestBandits(unittest.TestCase):
|
||||||
config = {
|
config = {
|
||||||
# Use a simple bandit friendly env.
|
# Use a simple bandit friendly env.
|
||||||
"env": SimpleContextualBandit,
|
"env": SimpleContextualBandit,
|
||||||
# Run locally.
|
|
||||||
"num_workers": 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
num_iterations = 5
|
num_iterations = 5
|
||||||
|
@ -38,6 +36,26 @@ class TestBandits(unittest.TestCase):
|
||||||
# Force good learning behavior (this is a very simple env).
|
# Force good learning behavior (this is a very simple env).
|
||||||
self.assertTrue(results["episode_reward_mean"] == 10.0)
|
self.assertTrue(results["episode_reward_mean"] == 10.0)
|
||||||
|
|
||||||
|
def test_bandit_lin_ucb_compilation(self):
    """Test whether a BanditLinUCBTrainer can be built and trained.

    Only the "torch" framework is iterated over. The trainer is trained
    for a fixed number of iterations; each result is passed to
    `check_train_results` and printed, and the final result must report
    an `episode_reward_mean` of exactly 10.0.
    """
    config = {
        # Use a simple bandit friendly env.
        "env": SimpleContextualBandit,
    }

    num_iterations = 5

    for _ in framework_iterator(config, frameworks="torch"):
        trainer = bandit.BanditLinUCBTrainer(config=config)
        try:
            results = None
            for _iteration in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            # Force good learning behavior (this is a very simple env).
            # assertEqual (rather than assertTrue(a == b)) reports both
            # values when the check fails.
            self.assertEqual(results["episode_reward_mean"], 10.0)
        finally:
            # Release the trainer's resources even if an assertion fails,
            # so later tests/framework iterations start clean.
            trainer.stop()
|
||||||
|
|
||||||
def test_deprecated_locations(self):
|
def test_deprecated_locations(self):
|
||||||
"""Tests, whether importing from old contrib dir fails gracefully.
|
"""Tests, whether importing from old contrib dir fails gracefully.
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue