mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
33 lines
664 B
Python
33 lines
664 B
Python
![]() |
import gym
|
||
|
import unittest
|
||
|
|
||
|
from ray.rllib.env.vector_env import VectorEnv
|
||
|
|
||
|
|
||
|
class Info(dict):
    """A ``dict`` subclass with no extra behavior.

    Instances are dictionaries whose exact type is not ``dict``. Presumably
    this exists to exercise code that validates ``info`` values with a type
    check (see the mock env's ``step``).
    """
|
||
|
|
||
|
|
||
|
class MockEnvDictSubclass(gym.Env):
    """One-step gym environment whose ``info`` is an ``Info`` (dict subclass).

    The observation space is ``Discrete(1)`` and the action space is
    ``Discrete(2)``. ``reset`` always yields observation 0. Every ``step``
    ignores the action and ends the episode with reward 1, returning an
    ``Info`` instance as the info value.
    """

    def __init__(self):
        # Build the spaces in the same order as before: observation, then action.
        obs_space = gym.spaces.Discrete(1)
        act_space = gym.spaces.Discrete(2)
        self.observation_space = obs_space
        self.action_space = act_space

    def reset(self):
        """Begin a new episode; the only possible observation is 0."""
        initial_obs = 0
        return initial_obs

    def step(self, action):
        """Ignore ``action`` and finish the episode immediately.

        Returns:
            The tuple ``(0, 1, True, Info())``: observation, reward, done flag,
            and a fresh ``Info`` dict-subclass instance.
        """
        obs, reward, done = 0, 1, True
        info = Info()
        return obs, reward, done, info
|
||
|
|
||
|
|
||
|
class TestExternalEnv(unittest.TestCase):
    # NOTE(review): despite its name, this case exercises VectorEnv, not an
    # external env; the name is kept so existing test IDs stay stable.

    def test_vector_step(self):
        """Step a vectorized env whose sub-envs return a dict-subclass ``info``.

        Each wrapped ``MockEnvDictSubclass`` returns an ``Info`` (a ``dict``
        subclass) from ``step``. ``vector_step`` must accept it and yield one
        result per sub-env.
        """
        num_envs = 3
        env = VectorEnv.wrap(
            lambda _: MockEnvDictSubclass(), num_envs=num_envs)
        # VectorEnv.vector_step returns per-env lists: obs, rewards, dones, infos.
        obs, rewards, dones, infos = env.vector_step([0] * num_envs)

        # Check the results, not just the absence of an exception.
        self.assertEqual(len(obs), num_envs)
        self.assertEqual(list(rewards), [1] * num_envs)
        self.assertEqual(list(dones), [True] * num_envs)
        self.assertEqual(len(infos), num_envs)
        for info in infos:
            self.assertIsInstance(info, dict)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run this module's tests verbosely through pytest and exit with its status.
    import sys

    import pytest

    pytest_args = ["-v", __file__]
    sys.exit(pytest.main(pytest_args))
|