ray/rllib/connectors/tests/test_action.py

105 lines
3.6 KiB
Python

import unittest
import gym
import numpy as np
from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ActionConnectorDataType
torch, _ = try_import_torch()
class TestActionConnector(unittest.TestCase):
def test_connector_pipeline(self):
ctx = ConnectorContext()
connectors = [ConvertToNumpyConnector(ctx)]
pipeline = ActionConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
def test_convert_to_numpy_connector(self):
ctx = ConnectorContext()
c = ConvertToNumpyConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "ConvertToNumpyConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ConvertToNumpyConnector))
action = torch.Tensor([8, 9])
states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
ac_data = ActionConnectorDataType(0, 1, (action, states, {}))
converted = c(ac_data)
self.assertTrue(isinstance(converted.output[0], np.ndarray))
self.assertTrue(isinstance(converted.output[1], np.ndarray))
def test_normalize_action_connector(self):
ctx = ConnectorContext(
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
)
c = NormalizeActionsConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "NormalizeActionsConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, NormalizeActionsConnector))
ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))
normalized = c(ac_data)
self.assertEqual(normalized.output[0], 4.5)
def test_clip_action_connector(self):
ctx = ConnectorContext(
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
)
c = ClipActionsConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "ClipActionsConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ClipActionsConnector))
ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))
clipped = c(ac_data)
self.assertEqual(clipped.output[0], 6.0)
def test_immutable_action_connector(self):
ctx = ConnectorContext(
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
)
c = ImmutableActionsConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "ImmutableActionsConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ImmutableActionsConnector))
ac_data = ActionConnectorDataType(0, 1, (np.array([8.8]), [], {}))
immutable = c(ac_data)
with self.assertRaises(ValueError):
immutable.output[0][0] = 5
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))