[RLlib] Fix broken test_distributions.py (test_categorical) (#12915)

This commit is contained in:
Sven Mika 2020-12-18 00:44:26 +01:00 committed by GitHub
parent d747071dd9
commit 124c8318a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 9 deletions

View file

@ -1089,13 +1089,12 @@ py_test(
srcs = ["models/tests/test_convtranspose2d_stack.py"]
)
# Failing after the following PR: https://github.com/ray-project/ray/pull/12760.
#py_test(
# name = "test_distributions",
# tags = ["models"],
# size = "medium",
# srcs = ["models/tests/test_distributions.py"]
#)
py_test(
name = "test_distributions",
tags = ["models"],
size = "medium",
srcs = ["models/tests/test_distributions.py"]
)
# --------------------------------------------------------------------
# Evaluation components

View file

@ -87,14 +87,15 @@ class TestDistributions(unittest.TestCase):
batch_size = 10000
num_categories = 4
# Create categorical distribution with n categories.
inputs_space = Box(-1.0, 2.0, shape=(batch_size, num_categories))
inputs_space = Box(
-1.0, 2.0, shape=(batch_size, num_categories), dtype=np.float32)
values_space = Box(
0, num_categories - 1, shape=(batch_size, ), dtype=np.int32)
inputs = inputs_space.sample()
for fw, sess in framework_iterator(
session=True, frameworks=("jax", "tf", "tf2", "torch")):
session=True, frameworks=("tf", "tf2", "torch")):
# Create the correct distribution object.
cls = JAXCategorical if fw == "jax" else Categorical if \
fw != "torch" else TorchCategorical