[RLlib] Fix flakey custom_fast_model_torch/tf tests. (#15330)

This commit is contained in:
Sven Mika 2021-04-15 16:10:29 +02:00 committed by GitHub
parent 981fa5829a
commit 45d6560759
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 12 deletions

View file

@ -1889,20 +1889,19 @@ py_test(
name = "examples/custom_fast_model_tf", name = "examples/custom_fast_model_tf",
main = "examples/custom_fast_model.py", main = "examples/custom_fast_model.py",
tags = ["examples", "examples_C"], tags = ["examples", "examples_C"],
size = "small", size = "medium",
srcs = ["examples/custom_fast_model.py"], srcs = ["examples/custom_fast_model.py"],
args = ["--stop-iters=1", "--num-cpus=4"] args = ["--stop-iters=1"]
) )
# Skip because it is consistently failing in the master. py_test(
# py_test( name = "examples/custom_fast_model_torch",
# name = "examples/custom_fast_model_torch", main = "examples/custom_fast_model.py",
# main = "examples/custom_fast_model.py", tags = ["examples", "examples_C"],
# tags = ["examples", "examples_C"], size = "medium",
# size = "small", srcs = ["examples/custom_fast_model.py"],
# srcs = ["examples/custom_fast_model.py"], args = ["--stop-iters=1", "--torch"]
# args = ["--torch", "--stop-iters=1", "--num-cpus=4"] )
# )
py_test( py_test(
name = "examples/custom_keras_model_a2c", name = "examples/custom_keras_model_a2c",

View file

@ -15,7 +15,7 @@ from ray.rllib.examples.models.fast_model import FastModel, TorchFastModel
from ray.rllib.models import ModelCatalog from ray.rllib.models import ModelCatalog
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--num-cpus", type=int, default=2) parser.add_argument("--num-cpus", type=int, default=4)
parser.add_argument("--torch", action="store_true") parser.add_argument("--torch", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200) parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=100000) parser.add_argument("--stop-timesteps", type=int, default=100000)