[RLlib] Rename connector's from/to config methods to better reflect that they include state. (#27806)

This commit is contained in:
Artur Niederfahrenhorst 2022-08-29 14:37:21 +02:00 committed by GitHub
parent 328e6ac2f4
commit 2ce80d8163
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 50 additions and 52 deletions

View file

@ -29,11 +29,11 @@ class ClipActionsConnector(ActionConnector):
(clip_action(actions, self._action_space_struct), states, fetches),
)
def to_config(self):
def to_state(self):
return ClipActionsConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return ClipActionsConnector(ctx)

View file

@ -28,11 +28,11 @@ class ImmutableActionsConnector(ActionConnector):
(actions, states, fetches),
)
def to_config(self):
def to_state(self):
return ImmutableActionsConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return ImmutableActionsConnector(ctx)

View file

@ -47,11 +47,11 @@ def register_lambda_action_connector(
fn(actions, states, fetches),
)
def to_config(self):
def to_state(self):
return name, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return LambdaActionConnector(ctx)
LambdaActionConnector.__name__ = name

View file

@ -32,11 +32,11 @@ class NormalizeActionsConnector(ActionConnector):
(unsquash_action(actions, self._action_space_struct), states, fetches),
)
def to_config(self):
def to_state(self):
return NormalizeActionsConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return NormalizeActionsConnector(ctx)

View file

@ -28,13 +28,11 @@ class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
ac_data = c(ac_data)
return ac_data
def to_config(self):
return ActionConnectorPipeline.__name__, [
c.to_config() for c in self.connectors
]
def to_state(self):
return ActionConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
), "ActionConnectorPipeline takes a list of connector params."

View file

@ -39,14 +39,14 @@ class ClipRewardAgentConnector(AgentConnector):
)
return ac_data
def to_config(self):
def to_state(self):
return ClipRewardAgentConnector.__name__, {
"sign": self.sign,
"limit": self.limit,
}
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return ClipRewardAgentConnector(ctx, **params)

View file

@ -39,11 +39,11 @@ def register_lambda_agent_connector(
ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
)
def to_config(self):
def to_state(self):
return name, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return LambdaAgentConnector(ctx)
LambdaAgentConnector.__name__ = name

View file

@ -56,11 +56,11 @@ class ObsPreprocessorConnector(AgentConnector):
return ac_data
def to_config(self):
def to_state(self):
return ObsPreprocessorConnector.__name__, {}
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return ObsPreprocessorConnector(ctx, **params)

View file

@ -39,11 +39,11 @@ class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
ret = c(ret)
return ret
def to_config(self):
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors]
def to_state(self):
return AgentConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
), "AgentConnectorPipeline takes a list of connector params."

View file

@ -71,11 +71,11 @@ class StateBufferConnector(AgentConnector):
return ac_data
def to_config(self):
def to_state(self):
return StateBufferConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return StateBufferConnector(ctx)

View file

@ -119,11 +119,11 @@ class ViewRequirementAgentConnector(AgentConnector):
)
return return_data
def to_config(self):
def to_state(self):
return ViewRequirementAgentConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
def from_state(ctx: ConnectorContext, params: List[Any]):
return ViewRequirementAgentConnector(ctx)

View file

@ -90,7 +90,7 @@ class Connector(abc.ABC):
Connectors may be training-aware, for example, behave slightly differently
during training and inference.
All connectors are required to be serializable and implement to_config().
All connectors are required to be serializable and implement to_state().
"""
def __init__(self, ctx: ConnectorContext):
@ -103,10 +103,10 @@ class Connector(abc.ABC):
def __str__(self, indentation: int = 0):
return " " * indentation + self.__class__.__name__
def to_config(self) -> Tuple[str, List[Any]]:
def to_state(self) -> Tuple[str, List[Any]]:
"""Serialize a connector into a JSON serializable Tuple.
to_config is required, so that all Connectors are serializable.
to_state is required, so that all Connectors are serializable.
Returns:
A tuple of connector's name and its serialized states.
@ -115,10 +115,10 @@ class Connector(abc.ABC):
return NotImplementedError
@staticmethod
def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
def from_state(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
"""De-serialize a JSON params back into a Connector.
from_config is required, so that all Connectors are serializable.
from_state is required, so that all Connectors are serializable.
Args:
ctx: Context for constructing this connector.
@ -266,11 +266,11 @@ class ActionConnector(Connector):
to user environments.
An action connector transforms a single piece of policy output in
ActionConnectorDataType format, which is basically PolicyOutputType
plus env and agent IDs.
ActionConnectorDataType format, which is basically PolicyOutputType plus env and
agent IDs.
Any functions that operates directly on PolicyOutputType can be
easily adpated into an ActionConnector by using register_lambda_action_connector.
Any functions that operate directly on PolicyOutputType can be easily adapted
into an ActionConnector by using register_lambda_action_connector.
Example:
.. code-block:: python
@ -432,4 +432,4 @@ def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Conne
if not _global_registry.contains(RLLIB_CONNECTOR, name):
raise NameError("connector not found.", name)
cls = _global_registry.get(RLLIB_CONNECTOR, name)
return cls.from_config(ctx, params)
return cls.from_state(ctx, params)

View file

@ -20,7 +20,7 @@ class TestActionConnector(unittest.TestCase):
ctx = ConnectorContext()
connectors = [ConvertToNumpyConnector(ctx)]
pipeline = ActionConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config()
name, params = pipeline.to_state()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
@ -29,7 +29,7 @@ class TestActionConnector(unittest.TestCase):
ctx = ConnectorContext()
c = ConvertToNumpyConnector(ctx)
name, params = c.to_config()
name, params = c.to_state()
self.assertEqual(name, "ConvertToNumpyConnector")
@ -50,7 +50,7 @@ class TestActionConnector(unittest.TestCase):
)
c = NormalizeActionsConnector(ctx)
name, params = c.to_config()
name, params = c.to_state()
self.assertEqual(name, "NormalizeActionsConnector")
restored = get_connector(ctx, name, params)
@ -67,7 +67,7 @@ class TestActionConnector(unittest.TestCase):
)
c = ClipActionsConnector(ctx)
name, params = c.to_config()
name, params = c.to_state()
self.assertEqual(name, "ClipActionsConnector")
restored = get_connector(ctx, name, params)
@ -84,7 +84,7 @@ class TestActionConnector(unittest.TestCase):
)
c = ImmutableActionsConnector(ctx)
name, params = c.to_config()
name, params = c.to_state()
self.assertEqual(name, "ImmutableActionsConnector")
restored = get_connector(ctx, name, params)

View file

@ -25,7 +25,7 @@ class TestAgentConnector(unittest.TestCase):
ctx = ConnectorContext()
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
pipeline = AgentConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config()
name, params = pipeline.to_state()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
@ -42,7 +42,7 @@ class TestAgentConnector(unittest.TestCase):
ctx = ConnectorContext(config={}, observation_space=obs_space)
c = ObsPreprocessorConnector(ctx)
name, params = c.to_config()
name, params = c.to_state()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
@ -70,7 +70,7 @@ class TestAgentConnector(unittest.TestCase):
ctx = ConnectorContext()
c = ClipRewardAgentConnector(ctx, limit=2.0)
name, params = c.to_config()
name, params = c.to_state()
self.assertEqual(name, "ClipRewardAgentConnector")
self.assertAlmostEqual(params["limit"], 2.0)
@ -95,7 +95,7 @@ class TestAgentConnector(unittest.TestCase):
c = FlattenDataAgentConnector(ctx)
name, params = c.to_config()
name, params = c.to_state()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
@ -413,11 +413,11 @@ class TestViewRequirementConnector(unittest.TestCase):
]
agent_connector = AgentConnectorPipeline(ctx, connectors)
name, params = agent_connector.to_config()
name, params = agent_connector.to_state()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
for cidx, c in enumerate(connectors):
check(restored.connectors[cidx].to_config(), c.to_config())
check(restored.connectors[cidx].to_state(), c.to_state())
# simulate a rollout
n_steps = 10

View file

@ -5,15 +5,15 @@ from ray.rllib.connectors.connector import Connector, ConnectorPipeline
class TestConnectorPipeline(unittest.TestCase):
class Tom(Connector):
def to_config():
def to_state():
return "tom"
class Bob(Connector):
def to_config():
def to_state():
return "bob"
class Mary(Connector):
def to_config():
def to_state():
return "mary"
class MockConnectorPipeline(ConnectorPipeline):

View file

@ -767,9 +767,9 @@ class Policy(metaclass=ABCMeta):
# Checkpoint connectors state as well if enabled.
connector_configs = {}
if self.agent_connectors:
connector_configs["agent"] = self.agent_connectors.to_config()
connector_configs["agent"] = self.agent_connectors.to_state()
if self.action_connectors:
connector_configs["action"] = self.action_connectors.to_config()
connector_configs["action"] = self.action_connectors.to_state()
state["connector_configs"] = connector_configs
return state