mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
import unittest
|
|
|
|
import gym
|
|
import numpy as np
|
|
|
|
from ray.rllib.connectors.action.clip import ClipActionsConnector
|
|
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
|
|
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
|
|
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
|
|
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
|
|
from ray.rllib.connectors.connector import ConnectorContext, get_connector
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
from ray.rllib.utils.typing import ActionConnectorDataType
|
|
|
|
torch, _ = try_import_torch()
|
|
|
|
|
|
class TestActionConnector(unittest.TestCase):
|
|
def test_connector_pipeline(self):
|
|
ctx = ConnectorContext()
|
|
connectors = [ConvertToNumpyConnector(ctx)]
|
|
pipeline = ActionConnectorPipeline(ctx, connectors)
|
|
name, params = pipeline.to_config()
|
|
restored = get_connector(ctx, name, params)
|
|
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
|
|
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
|
|
|
|
def test_convert_to_numpy_connector(self):
|
|
ctx = ConnectorContext()
|
|
c = ConvertToNumpyConnector(ctx)
|
|
|
|
name, params = c.to_config()
|
|
|
|
self.assertEqual(name, "ConvertToNumpyConnector")
|
|
|
|
restored = get_connector(ctx, name, params)
|
|
self.assertTrue(isinstance(restored, ConvertToNumpyConnector))
|
|
|
|
action = torch.Tensor([8, 9])
|
|
states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
|
|
ac_data = ActionConnectorDataType(0, 1, (action, states, {}))
|
|
|
|
converted = c(ac_data)
|
|
self.assertTrue(isinstance(converted.output[0], np.ndarray))
|
|
self.assertTrue(isinstance(converted.output[1], np.ndarray))
|
|
|
|
def test_normalize_action_connector(self):
|
|
ctx = ConnectorContext(
|
|
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
|
)
|
|
c = NormalizeActionsConnector(ctx)
|
|
|
|
name, params = c.to_config()
|
|
self.assertEqual(name, "NormalizeActionsConnector")
|
|
|
|
restored = get_connector(ctx, name, params)
|
|
self.assertTrue(isinstance(restored, NormalizeActionsConnector))
|
|
|
|
ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))
|
|
|
|
normalized = c(ac_data)
|
|
self.assertEqual(normalized.output[0], 4.5)
|
|
|
|
def test_clip_action_connector(self):
|
|
ctx = ConnectorContext(
|
|
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
|
)
|
|
c = ClipActionsConnector(ctx)
|
|
|
|
name, params = c.to_config()
|
|
self.assertEqual(name, "ClipActionsConnector")
|
|
|
|
restored = get_connector(ctx, name, params)
|
|
self.assertTrue(isinstance(restored, ClipActionsConnector))
|
|
|
|
ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))
|
|
|
|
clipped = c(ac_data)
|
|
self.assertEqual(clipped.output[0], 6.0)
|
|
|
|
def test_immutable_action_connector(self):
|
|
ctx = ConnectorContext(
|
|
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
|
)
|
|
c = ImmutableActionsConnector(ctx)
|
|
|
|
name, params = c.to_config()
|
|
self.assertEqual(name, "ImmutableActionsConnector")
|
|
|
|
restored = get_connector(ctx, name, params)
|
|
self.assertTrue(isinstance(restored, ImmutableActionsConnector))
|
|
|
|
ac_data = ActionConnectorDataType(0, 1, (np.array([8.8]), [], {}))
|
|
|
|
immutable = c(ac_data)
|
|
|
|
with self.assertRaises(ValueError):
|
|
immutable.output[0][0] = 5
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import pytest
|
|
import sys
|
|
|
|
sys.exit(pytest.main(["-v", __file__]))
|