mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
![]() |
import unittest
|
||
|
|
||
|
from ray.rllib.connectors.connector import Connector, ConnectorPipeline
|
||
|
|
||
|
|
||
|
class TestConnectorPipeline(unittest.TestCase):
    """Sanity checks for the list-editing helpers of ``ConnectorPipeline``.

    The nested ``Tom``, ``Bob`` and ``Mary`` connectors are trivial
    ``Connector`` subclasses. The test refers to them by class name when
    calling ``insert_before``, ``insert_after`` and ``remove``.
    """

    class Tom(Connector):
        """Trivial connector whose config is the string ``"tom"``."""

        # to_config takes ``self`` so it can be called on an instance.
        def to_config(self):
            return "tom"

    class Bob(Connector):
        """Trivial connector whose config is the string ``"bob"``."""

        def to_config(self):
            return "bob"

    class Mary(Connector):
        """Trivial connector whose config is the string ``"mary"``."""

        def to_config(self):
            return "mary"

    class MockConnectorPipeline(ConnectorPipeline):
        """Pipeline stub that only holds the connectors it is given."""

        def __init__(self, ctx, connectors):
            # Real connector pipelines should keep a list of Connectors.
            # This mock stores the list it is handed as-is; ``ctx`` is not
            # used by the mock itself.
            self.connectors = connectors

    def test_sanity_check(self):
        """Check insert_before, insert_after, prepend, append and remove."""
        ctx = {}

        # Inserting before "Bob" places Mary between Tom and Bob.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.insert_before("Bob", self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[1].__class__.__name__, "Mary")

        # Inserting after "Tom" yields the same ordering.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.insert_after("Tom", self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[1].__class__.__name__, "Mary")

        # Prepending puts Mary at the front.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.prepend(self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[0].__class__.__name__, "Mary")

        # Appending puts Mary at the end.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.append(self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[2].__class__.__name__, "Mary")

        # Removing "Bob" from [Tom, Bob, Mary] leaves [Tom, Mary].
        m.remove("Bob")
        self.assertEqual(len(m.connectors), 2)
        self.assertEqual(m.connectors[0].__class__.__name__, "Tom")
        self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # When executed directly, hand this file to pytest in verbose mode and
    # use pytest's return value as the process exit status.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
|