[RLlib] Attempt splitting rollout test to avoid initial timeout (#14999)

This commit is contained in:
Eric Liang 2021-03-30 10:20:02 -07:00 committed by GitHub
parent ccb0cdaa35
commit b90cc51c27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 3 deletions

View file

@ -1506,13 +1506,43 @@ py_test(
# Test train/rollout scripts (w/o confirming rollout performance).
py_test(
name = "test_rollout_no_learning",
name = "test_rollout_no_learning1",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"],
args = ["TestRolloutSimple"]
args = ["TestRolloutSimple1"]
)
py_test(
name = "test_rollout_no_learning2",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"],
args = ["TestRolloutSimple2"]
)
py_test(
name = "test_rollout_no_learning3",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"],
args = ["TestRolloutSimple3"]
)
py_test(
name = "test_rollout_no_learning4",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"],
args = ["TestRolloutSimple4"]
)
# Test train/rollout scripts (and confirm `rllib rollout` performance is same

View file

@ -226,25 +226,31 @@ def learn_test_multi_agent_plus_rollout(algo):
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
class TestRolloutSimple(unittest.TestCase):
class TestRolloutSimple1(unittest.TestCase):
def test_a3c(self):
rollout_test("A3C")
def test_ddpg(self):
rollout_test("DDPG", env="Pendulum-v0")
class TestRolloutSimple2(unittest.TestCase):
def test_dqn(self):
rollout_test("DQN")
def test_es(self):
rollout_test("ES")
class TestRolloutSimple3(unittest.TestCase):
def test_impala(self):
rollout_test("IMPALA", env="CartPole-v0")
def test_ppo(self):
rollout_test("PPO", env="CartPole-v0", test_episode_rollout=True)
class TestRolloutSimple4(unittest.TestCase):
def test_sac(self):
rollout_test("SAC", env="Pendulum-v0")