mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Attempt splitting rollout test to avoid initial timeout (#14999)
This commit is contained in:
parent
ccb0cdaa35
commit
b90cc51c27
2 changed files with 39 additions and 3 deletions
34
rllib/BUILD
34
rllib/BUILD
|
@ -1506,13 +1506,43 @@ py_test(
|
|||
|
||||
# Test train/rollout scripts (w/o confirming rollout performance).
|
||||
py_test(
|
||||
name = "test_rollout_no_learning",
|
||||
name = "test_rollout_no_learning1",
|
||||
main = "tests/test_rollout.py",
|
||||
tags = ["tests_dir", "tests_dir_R"],
|
||||
size = "large",
|
||||
data = ["train.py", "rollout.py"],
|
||||
srcs = ["tests/test_rollout.py"],
|
||||
args = ["TestRolloutSimple"]
|
||||
args = ["TestRolloutSimple1"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_rollout_no_learning2",
|
||||
main = "tests/test_rollout.py",
|
||||
tags = ["tests_dir", "tests_dir_R"],
|
||||
size = "large",
|
||||
data = ["train.py", "rollout.py"],
|
||||
srcs = ["tests/test_rollout.py"],
|
||||
args = ["TestRolloutSimple2"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_rollout_no_learning3",
|
||||
main = "tests/test_rollout.py",
|
||||
tags = ["tests_dir", "tests_dir_R"],
|
||||
size = "large",
|
||||
data = ["train.py", "rollout.py"],
|
||||
srcs = ["tests/test_rollout.py"],
|
||||
args = ["TestRolloutSimple3"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_rollout_no_learning4",
|
||||
main = "tests/test_rollout.py",
|
||||
tags = ["tests_dir", "tests_dir_R"],
|
||||
size = "large",
|
||||
data = ["train.py", "rollout.py"],
|
||||
srcs = ["tests/test_rollout.py"],
|
||||
args = ["TestRolloutSimple4"]
|
||||
)
|
||||
|
||||
# Test train/rollout scripts (and confirm `rllib rollout` performance is same
|
||||
|
|
|
@ -226,25 +226,31 @@ def learn_test_multi_agent_plus_rollout(algo):
|
|||
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
|
||||
|
||||
|
||||
class TestRolloutSimple(unittest.TestCase):
|
||||
class TestRolloutSimple1(unittest.TestCase):
|
||||
def test_a3c(self):
|
||||
rollout_test("A3C")
|
||||
|
||||
def test_ddpg(self):
|
||||
rollout_test("DDPG", env="Pendulum-v0")
|
||||
|
||||
|
||||
class TestRolloutSimple2(unittest.TestCase):
|
||||
def test_dqn(self):
|
||||
rollout_test("DQN")
|
||||
|
||||
def test_es(self):
|
||||
rollout_test("ES")
|
||||
|
||||
|
||||
class TestRolloutSimple3(unittest.TestCase):
|
||||
def test_impala(self):
|
||||
rollout_test("IMPALA", env="CartPole-v0")
|
||||
|
||||
def test_ppo(self):
|
||||
rollout_test("PPO", env="CartPole-v0", test_episode_rollout=True)
|
||||
|
||||
|
||||
class TestRolloutSimple4(unittest.TestCase):
|
||||
def test_sac(self):
|
||||
rollout_test("SAC", env="Pendulum-v0")
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue