# ray/rllib/connectors/tests/test_action.py
#
# 127 lines
# 4 KiB
# Python

import gym
import numpy as np
import unittest
from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.lambdas import (
ConvertToNumpyConnector,
UnbatchActionsConnector,
)
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.connector import (
ConnectorContext,
get_connector,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ActionConnectorDataType
# Resolve torch through RLlib's import helper; the second element of the
# returned pair is not needed in this module and is discarded.
torch, _ = try_import_torch()
class TestActionConnector(unittest.TestCase):
    """Unit tests for the action connectors and the action connector pipeline.

    Every test serializes its subject with ``to_config()``, rebuilds it with
    ``get_connector()`` and checks the type of the rebuilt object. The
    per-connector tests additionally call the original connector on a small
    hand-built ``ActionConnectorDataType`` and check the transformed output.
    """

    def test_connector_pipeline(self):
        """A pipeline rebuilt from its config keeps its type and its connector."""
        ctx = ConnectorContext()
        connectors = [ConvertToNumpyConnector(ctx)]
        pipeline = ActionConnectorPipeline(ctx, connectors)

        name, params = pipeline.to_config()
        restored = get_connector(ctx, name, params)

        self.assertIsInstance(restored, ActionConnectorPipeline)
        self.assertIsInstance(restored.connectors[0], ConvertToNumpyConnector)

    def test_convert_to_numpy_connector(self):
        """Torch tensors for actions and states come out as numpy arrays."""
        ctx = ConnectorContext()
        c = ConvertToNumpyConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "ConvertToNumpyConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ConvertToNumpyConnector)

        action = torch.Tensor([8, 9])
        states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
        # NOTE(review): the two leading positional arguments are presumably the
        # env id and agent id -- confirm against ActionConnectorDataType.
        ac_data = ActionConnectorDataType(0, 1, (action, states, {}))

        converted = c(ac_data)
        self.assertIsInstance(converted.output[0], np.ndarray)
        self.assertIsInstance(converted.output[1], np.ndarray)

    def test_unbatch_action_connector(self):
        """A batched nested action struct is split into one struct per item."""
        ctx = ConnectorContext()
        c = UnbatchActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "UnbatchActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, UnbatchActionsConnector)

        # A batch of three actions: "a" is a flat array, "b" is a tuple of two
        # arrays, so item i is expected to be {"a": a[i], "b": (b0[i], b1[i])}.
        ac_data = ActionConnectorDataType(
            0,
            1,
            (
                {
                    "a": np.array([1, 2, 3]),
                    "b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
                },
                [],
                {},
            ),
        )

        unbatched = c(ac_data)
        actions, _, _ = unbatched.output

        expected = [
            (1, (4, 7)),
            (2, (5, 8)),
            (3, (6, 9)),
        ]
        self.assertEqual(len(actions), len(expected))
        for idx, (expected_a, expected_b) in enumerate(expected):
            with self.subTest(index=idx):
                self.assertEqual(actions[idx]["a"], expected_a)
                # assert_array_equal reports the differing values on failure,
                # unlike assertTrue((x == y).all()).
                np.testing.assert_array_equal(
                    actions[idx]["b"], np.array(expected_b)
                )

    def test_normalize_action_connector(self):
        """An input of 0.5 maps to 4.5 for a Box action space spanning [0, 6].

        NOTE(review): the expected value is consistent with rescaling from
        [-1, 1] into the Box bounds -- confirm against
        NormalizeActionsConnector.
        """
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
        )
        c = NormalizeActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "NormalizeActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, NormalizeActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))

        normalized = c(ac_data)
        self.assertEqual(normalized.output[0], 4.5)

    def test_clip_action_connector(self):
        """An input of 8.8 is clipped to 6.0, the upper bound of the Box."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
        )
        c = ClipActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "ClipActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ClipActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))

        clipped = c(ac_data)
        self.assertEqual(clipped.output[0], 6.0)
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and
    # uses pytest's result as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)