[RLlib] Rename connector's from/to config methods to better reflect that they include state. (#27806)

This commit is contained in:
Artur Niederfahrenhorst 2022-08-29 14:37:21 +02:00 committed by GitHub
parent 328e6ac2f4
commit 2ce80d8163
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 50 additions and 52 deletions

View file

@ -29,11 +29,11 @@ class ClipActionsConnector(ActionConnector):
(clip_action(actions, self._action_space_struct), states, fetches), (clip_action(actions, self._action_space_struct), states, fetches),
) )
def to_config(self): def to_state(self):
return ClipActionsConnector.__name__, None return ClipActionsConnector.__name__, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return ClipActionsConnector(ctx) return ClipActionsConnector(ctx)

View file

@ -28,11 +28,11 @@ class ImmutableActionsConnector(ActionConnector):
(actions, states, fetches), (actions, states, fetches),
) )
def to_config(self): def to_state(self):
return ImmutableActionsConnector.__name__, None return ImmutableActionsConnector.__name__, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return ImmutableActionsConnector(ctx) return ImmutableActionsConnector(ctx)

View file

@ -47,11 +47,11 @@ def register_lambda_action_connector(
fn(actions, states, fetches), fn(actions, states, fetches),
) )
def to_config(self): def to_state(self):
return name, None return name, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return LambdaActionConnector(ctx) return LambdaActionConnector(ctx)
LambdaActionConnector.__name__ = name LambdaActionConnector.__name__ = name

View file

@ -32,11 +32,11 @@ class NormalizeActionsConnector(ActionConnector):
(unsquash_action(actions, self._action_space_struct), states, fetches), (unsquash_action(actions, self._action_space_struct), states, fetches),
) )
def to_config(self): def to_state(self):
return NormalizeActionsConnector.__name__, None return NormalizeActionsConnector.__name__, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return NormalizeActionsConnector(ctx) return NormalizeActionsConnector(ctx)

View file

@ -28,13 +28,11 @@ class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
ac_data = c(ac_data) ac_data = c(ac_data)
return ac_data return ac_data
def to_config(self): def to_state(self):
return ActionConnectorPipeline.__name__, [ return ActionConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
c.to_config() for c in self.connectors
]
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
assert ( assert (
type(params) == list type(params) == list
), "ActionConnectorPipeline takes a list of connector params." ), "ActionConnectorPipeline takes a list of connector params."

View file

@ -39,14 +39,14 @@ class ClipRewardAgentConnector(AgentConnector):
) )
return ac_data return ac_data
def to_config(self): def to_state(self):
return ClipRewardAgentConnector.__name__, { return ClipRewardAgentConnector.__name__, {
"sign": self.sign, "sign": self.sign,
"limit": self.limit, "limit": self.limit,
} }
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return ClipRewardAgentConnector(ctx, **params) return ClipRewardAgentConnector(ctx, **params)

View file

@ -39,11 +39,11 @@ def register_lambda_agent_connector(
ac_data.env_id, ac_data.agent_id, fn(ac_data.data) ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
) )
def to_config(self): def to_state(self):
return name, None return name, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return LambdaAgentConnector(ctx) return LambdaAgentConnector(ctx)
LambdaAgentConnector.__name__ = name LambdaAgentConnector.__name__ = name

View file

@ -56,11 +56,11 @@ class ObsPreprocessorConnector(AgentConnector):
return ac_data return ac_data
def to_config(self): def to_state(self):
return ObsPreprocessorConnector.__name__, {} return ObsPreprocessorConnector.__name__, {}
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return ObsPreprocessorConnector(ctx, **params) return ObsPreprocessorConnector(ctx, **params)

View file

@ -39,11 +39,11 @@ class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
ret = c(ret) ret = c(ret)
return ret return ret
def to_config(self): def to_state(self):
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors] return AgentConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
assert ( assert (
type(params) == list type(params) == list
), "AgentConnectorPipeline takes a list of connector params." ), "AgentConnectorPipeline takes a list of connector params."

View file

@ -71,11 +71,11 @@ class StateBufferConnector(AgentConnector):
return ac_data return ac_data
def to_config(self): def to_state(self):
return StateBufferConnector.__name__, None return StateBufferConnector.__name__, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return StateBufferConnector(ctx) return StateBufferConnector(ctx)

View file

@ -119,11 +119,11 @@ class ViewRequirementAgentConnector(AgentConnector):
) )
return return_data return return_data
def to_config(self): def to_state(self):
return ViewRequirementAgentConnector.__name__, None return ViewRequirementAgentConnector.__name__, None
@staticmethod @staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]): def from_state(ctx: ConnectorContext, params: List[Any]):
return ViewRequirementAgentConnector(ctx) return ViewRequirementAgentConnector(ctx)

View file

@ -90,7 +90,7 @@ class Connector(abc.ABC):
Connectors may be training-aware, for example, behave slightly differently Connectors may be training-aware, for example, behave slightly differently
during training and inference. during training and inference.
All connectors are required to be serializable and implement to_config(). All connectors are required to be serializable and implement to_state().
""" """
def __init__(self, ctx: ConnectorContext): def __init__(self, ctx: ConnectorContext):
@ -103,10 +103,10 @@ class Connector(abc.ABC):
def __str__(self, indentation: int = 0): def __str__(self, indentation: int = 0):
return " " * indentation + self.__class__.__name__ return " " * indentation + self.__class__.__name__
def to_config(self) -> Tuple[str, List[Any]]: def to_state(self) -> Tuple[str, List[Any]]:
"""Serialize a connector into a JSON serializable Tuple. """Serialize a connector into a JSON serializable Tuple.
to_config is required, so that all Connectors are serializable. to_state is required, so that all Connectors are serializable.
Returns: Returns:
A tuple of connector's name and its serialized states. A tuple of connector's name and its serialized states.
@ -115,10 +115,10 @@ class Connector(abc.ABC):
return NotImplementedError return NotImplementedError
@staticmethod @staticmethod
def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector": def from_state(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
"""De-serialize a JSON params back into a Connector. """De-serialize a JSON params back into a Connector.
from_config is required, so that all Connectors are serializable. from_state is required, so that all Connectors are serializable.
Args: Args:
ctx: Context for constructing this connector. ctx: Context for constructing this connector.
@ -266,11 +266,11 @@ class ActionConnector(Connector):
to user environments. to user environments.
An action connector transforms a single piece of policy output in An action connector transforms a single piece of policy output in
ActionConnectorDataType format, which is basically PolicyOutputType ActionConnectorDataType format, which is basically PolicyOutputType plus env and
plus env and agent IDs. agent IDs.
Any functions that operates directly on PolicyOutputType can be Any functions that operate directly on PolicyOutputType can be easily adapted
easily adpated into an ActionConnector by using register_lambda_action_connector. into an ActionConnector by using register_lambda_action_connector.
Example: Example:
.. code-block:: python .. code-block:: python
@ -432,4 +432,4 @@ def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Conne
if not _global_registry.contains(RLLIB_CONNECTOR, name): if not _global_registry.contains(RLLIB_CONNECTOR, name):
raise NameError("connector not found.", name) raise NameError("connector not found.", name)
cls = _global_registry.get(RLLIB_CONNECTOR, name) cls = _global_registry.get(RLLIB_CONNECTOR, name)
return cls.from_config(ctx, params) return cls.from_state(ctx, params)

View file

@ -20,7 +20,7 @@ class TestActionConnector(unittest.TestCase):
ctx = ConnectorContext() ctx = ConnectorContext()
connectors = [ConvertToNumpyConnector(ctx)] connectors = [ConvertToNumpyConnector(ctx)]
pipeline = ActionConnectorPipeline(ctx, connectors) pipeline = ActionConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config() name, params = pipeline.to_state()
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ActionConnectorPipeline)) self.assertTrue(isinstance(restored, ActionConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector)) self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
@ -29,7 +29,7 @@ class TestActionConnector(unittest.TestCase):
ctx = ConnectorContext() ctx = ConnectorContext()
c = ConvertToNumpyConnector(ctx) c = ConvertToNumpyConnector(ctx)
name, params = c.to_config() name, params = c.to_state()
self.assertEqual(name, "ConvertToNumpyConnector") self.assertEqual(name, "ConvertToNumpyConnector")
@ -50,7 +50,7 @@ class TestActionConnector(unittest.TestCase):
) )
c = NormalizeActionsConnector(ctx) c = NormalizeActionsConnector(ctx)
name, params = c.to_config() name, params = c.to_state()
self.assertEqual(name, "NormalizeActionsConnector") self.assertEqual(name, "NormalizeActionsConnector")
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
@ -67,7 +67,7 @@ class TestActionConnector(unittest.TestCase):
) )
c = ClipActionsConnector(ctx) c = ClipActionsConnector(ctx)
name, params = c.to_config() name, params = c.to_state()
self.assertEqual(name, "ClipActionsConnector") self.assertEqual(name, "ClipActionsConnector")
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
@ -84,7 +84,7 @@ class TestActionConnector(unittest.TestCase):
) )
c = ImmutableActionsConnector(ctx) c = ImmutableActionsConnector(ctx)
name, params = c.to_config() name, params = c.to_state()
self.assertEqual(name, "ImmutableActionsConnector") self.assertEqual(name, "ImmutableActionsConnector")
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)

View file

@ -25,7 +25,7 @@ class TestAgentConnector(unittest.TestCase):
ctx = ConnectorContext() ctx = ConnectorContext()
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)] connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
pipeline = AgentConnectorPipeline(ctx, connectors) pipeline = AgentConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config() name, params = pipeline.to_state()
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline)) self.assertTrue(isinstance(restored, AgentConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector)) self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
@ -42,7 +42,7 @@ class TestAgentConnector(unittest.TestCase):
ctx = ConnectorContext(config={}, observation_space=obs_space) ctx = ConnectorContext(config={}, observation_space=obs_space)
c = ObsPreprocessorConnector(ctx) c = ObsPreprocessorConnector(ctx)
name, params = c.to_config() name, params = c.to_state()
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ObsPreprocessorConnector)) self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
@ -70,7 +70,7 @@ class TestAgentConnector(unittest.TestCase):
ctx = ConnectorContext() ctx = ConnectorContext()
c = ClipRewardAgentConnector(ctx, limit=2.0) c = ClipRewardAgentConnector(ctx, limit=2.0)
name, params = c.to_config() name, params = c.to_state()
self.assertEqual(name, "ClipRewardAgentConnector") self.assertEqual(name, "ClipRewardAgentConnector")
self.assertAlmostEqual(params["limit"], 2.0) self.assertAlmostEqual(params["limit"], 2.0)
@ -95,7 +95,7 @@ class TestAgentConnector(unittest.TestCase):
c = FlattenDataAgentConnector(ctx) c = FlattenDataAgentConnector(ctx)
name, params = c.to_config() name, params = c.to_state()
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, FlattenDataAgentConnector)) self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
@ -413,11 +413,11 @@ class TestViewRequirementConnector(unittest.TestCase):
] ]
agent_connector = AgentConnectorPipeline(ctx, connectors) agent_connector = AgentConnectorPipeline(ctx, connectors)
name, params = agent_connector.to_config() name, params = agent_connector.to_state()
restored = get_connector(ctx, name, params) restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline)) self.assertTrue(isinstance(restored, AgentConnectorPipeline))
for cidx, c in enumerate(connectors): for cidx, c in enumerate(connectors):
check(restored.connectors[cidx].to_config(), c.to_config()) check(restored.connectors[cidx].to_state(), c.to_state())
# simulate a rollout # simulate a rollout
n_steps = 10 n_steps = 10

View file

@ -5,15 +5,15 @@ from ray.rllib.connectors.connector import Connector, ConnectorPipeline
class TestConnectorPipeline(unittest.TestCase): class TestConnectorPipeline(unittest.TestCase):
class Tom(Connector): class Tom(Connector):
def to_config(): def to_state():
return "tom" return "tom"
class Bob(Connector): class Bob(Connector):
def to_config(): def to_state():
return "bob" return "bob"
class Mary(Connector): class Mary(Connector):
def to_config(): def to_state():
return "mary" return "mary"
class MockConnectorPipeline(ConnectorPipeline): class MockConnectorPipeline(ConnectorPipeline):

View file

@ -767,9 +767,9 @@ class Policy(metaclass=ABCMeta):
# Checkpoint connectors state as well if enabled. # Checkpoint connectors state as well if enabled.
connector_configs = {} connector_configs = {}
if self.agent_connectors: if self.agent_connectors:
connector_configs["agent"] = self.agent_connectors.to_config() connector_configs["agent"] = self.agent_connectors.to_state()
if self.action_connectors: if self.action_connectors:
connector_configs["action"] = self.action_connectors.to_config() connector_configs["action"] = self.action_connectors.to_state()
state["connector_configs"] = connector_configs state["connector_configs"] = connector_configs
return state return state