mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
import gym
|
|
import unittest
|
|
|
|
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
|
|
from ray.rllib.models.tf.visionnet import VisionNetwork
|
|
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVision
|
|
from ray.rllib.utils.framework import try_import_torch, try_import_tf
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
|
|
# Unpack the two values returned by try_import_torch(). `torch` is used by the
# test below to turn the sampled observation into a tensor for the "torch"
# framework; `nn` is not referenced in the code visible here.
torch, nn = try_import_torch()
|
|
# Unpack the three values returned by try_import_tf(). None of them is
# referenced directly in the code visible here.
# NOTE(review): presumably (tf1 compat module, tf module, TF major version) --
# confirm against ray.rllib.utils.framework.
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
class TestConv2DDefaultStacks(unittest.TestCase):
    """Tests the default Conv2D stacks used for image observation spaces."""

    def test_conv2d_default_stacks(self):
        """Tests, whether conv2d defaults are available for img obs spaces.

        For every image shape listed below and every framework yielded by
        `framework_iterator()`, a model is built from a copy of
        `MODEL_DEFAULTS` via `ModelCatalog.get_model_v2`. The test then
        checks that one of the two vision network classes was selected and
        that a forward pass on a single sampled observation yields an output
        of shape `(1, num_outputs)`.
        """
        action_space = gym.spaces.Discrete(2)
        # Requested number of model outputs (one logit per discrete action).
        num_outputs = 2

        shapes = [
            (480, 640, 3),
            (240, 320, 3),
            (96, 96, 3),
            (84, 84, 3),
            (42, 42, 3),
            (10, 10, 3),
        ]
        for shape in shapes:
            print(f"shape={shape}")
            obs_space = gym.spaces.Box(-1.0, 1.0, shape=shape)
            for fw in framework_iterator():
                model = ModelCatalog.get_model_v2(
                    obs_space,
                    action_space,
                    num_outputs,
                    MODEL_DEFAULTS.copy(),
                    framework=fw,
                )
                # assertIsInstance reports the actual type on failure.
                self.assertIsInstance(model, (VisionNetwork, TorchVision))
                # `[None]` prepends a batch axis of size 1 to the sample.
                obs_batch = obs_space.sample()[None]
                if fw == "torch":
                    # The torch model is fed a tensor instead of a numpy array.
                    obs_batch = torch.from_numpy(obs_batch)
                output, _ = model({"obs": obs_batch})
                # B x [action logits]
                # assertEqual shows both shapes if they differ.
                self.assertEqual(output.shape, (1, num_outputs))
                print("ok")
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest in verbose mode and
    # use pytest's return value as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|