mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Fix broken test_distributions.py (test_categorical) (#12915)
This commit is contained in:
parent
d747071dd9
commit
124c8318a8
2 changed files with 9 additions and 9 deletions
13
rllib/BUILD
13
rllib/BUILD
|
@ -1089,13 +1089,12 @@ py_test(
|
|||
srcs = ["models/tests/test_convtranspose2d_stack.py"]
|
||||
)
|
||||
|
||||
# Failing after the following PR: https://github.com/ray-project/ray/pull/12760.
|
||||
#py_test(
|
||||
# name = "test_distributions",
|
||||
# tags = ["models"],
|
||||
# size = "medium",
|
||||
# srcs = ["models/tests/test_distributions.py"]
|
||||
#)
|
||||
py_test(
|
||||
name = "test_distributions",
|
||||
tags = ["models"],
|
||||
size = "medium",
|
||||
srcs = ["models/tests/test_distributions.py"]
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Evaluation components
|
||||
|
|
|
@ -87,14 +87,15 @@ class TestDistributions(unittest.TestCase):
|
|||
batch_size = 10000
|
||||
num_categories = 4
|
||||
# Create categorical distribution with n categories.
|
||||
inputs_space = Box(-1.0, 2.0, shape=(batch_size, num_categories))
|
||||
inputs_space = Box(
|
||||
-1.0, 2.0, shape=(batch_size, num_categories), dtype=np.float32)
|
||||
values_space = Box(
|
||||
0, num_categories - 1, shape=(batch_size, ), dtype=np.int32)
|
||||
|
||||
inputs = inputs_space.sample()
|
||||
|
||||
for fw, sess in framework_iterator(
|
||||
session=True, frameworks=("jax", "tf", "tf2", "torch")):
|
||||
session=True, frameworks=("tf", "tf2", "torch")):
|
||||
# Create the correct distribution object.
|
||||
cls = JAXCategorical if fw == "jax" else Categorical if \
|
||||
fw != "torch" else TorchCategorical
|
||||
|
|
Loading…
Add table
Reference in a new issue