mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[RLlib] Rename connector's from/to config methods to better reflect that they include state. (#27806)
This commit is contained in:
parent
328e6ac2f4
commit
2ce80d8163
16 changed files with 50 additions and 52 deletions
|
@ -29,11 +29,11 @@ class ClipActionsConnector(ActionConnector):
|
|||
(clip_action(actions, self._action_space_struct), states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return ClipActionsConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return ClipActionsConnector(ctx)
|
||||
|
||||
|
||||
|
|
|
@ -28,11 +28,11 @@ class ImmutableActionsConnector(ActionConnector):
|
|||
(actions, states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return ImmutableActionsConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return ImmutableActionsConnector(ctx)
|
||||
|
||||
|
||||
|
|
|
@ -47,11 +47,11 @@ def register_lambda_action_connector(
|
|||
fn(actions, states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return name, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return LambdaActionConnector(ctx)
|
||||
|
||||
LambdaActionConnector.__name__ = name
|
||||
|
|
|
@ -32,11 +32,11 @@ class NormalizeActionsConnector(ActionConnector):
|
|||
(unsquash_action(actions, self._action_space_struct), states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return NormalizeActionsConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return NormalizeActionsConnector(ctx)
|
||||
|
||||
|
||||
|
|
|
@ -28,13 +28,11 @@ class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
|
|||
ac_data = c(ac_data)
|
||||
return ac_data
|
||||
|
||||
def to_config(self):
|
||||
return ActionConnectorPipeline.__name__, [
|
||||
c.to_config() for c in self.connectors
|
||||
]
|
||||
def to_state(self):
|
||||
return ActionConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
assert (
|
||||
type(params) == list
|
||||
), "ActionConnectorPipeline takes a list of connector params."
|
||||
|
|
|
@ -39,14 +39,14 @@ class ClipRewardAgentConnector(AgentConnector):
|
|||
)
|
||||
return ac_data
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return ClipRewardAgentConnector.__name__, {
|
||||
"sign": self.sign,
|
||||
"limit": self.limit,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return ClipRewardAgentConnector(ctx, **params)
|
||||
|
||||
|
||||
|
|
|
@ -39,11 +39,11 @@ def register_lambda_agent_connector(
|
|||
ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return name, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return LambdaAgentConnector(ctx)
|
||||
|
||||
LambdaAgentConnector.__name__ = name
|
||||
|
|
|
@ -56,11 +56,11 @@ class ObsPreprocessorConnector(AgentConnector):
|
|||
|
||||
return ac_data
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return ObsPreprocessorConnector.__name__, {}
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return ObsPreprocessorConnector(ctx, **params)
|
||||
|
||||
|
||||
|
|
|
@ -39,11 +39,11 @@ class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
|
|||
ret = c(ret)
|
||||
return ret
|
||||
|
||||
def to_config(self):
|
||||
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors]
|
||||
def to_state(self):
|
||||
return AgentConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
assert (
|
||||
type(params) == list
|
||||
), "AgentConnectorPipeline takes a list of connector params."
|
||||
|
|
|
@ -71,11 +71,11 @@ class StateBufferConnector(AgentConnector):
|
|||
|
||||
return ac_data
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return StateBufferConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return StateBufferConnector(ctx)
|
||||
|
||||
|
||||
|
|
|
@ -119,11 +119,11 @@ class ViewRequirementAgentConnector(AgentConnector):
|
|||
)
|
||||
return return_data
|
||||
|
||||
def to_config(self):
|
||||
def to_state(self):
|
||||
return ViewRequirementAgentConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||
return ViewRequirementAgentConnector(ctx)
|
||||
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class Connector(abc.ABC):
|
|||
Connectors may be training-aware, for example, behave slightly differently
|
||||
during training and inference.
|
||||
|
||||
All connectors are required to be serializable and implement to_config().
|
||||
All connectors are required to be serializable and implement to_state().
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
|
@ -103,10 +103,10 @@ class Connector(abc.ABC):
|
|||
def __str__(self, indentation: int = 0):
|
||||
return " " * indentation + self.__class__.__name__
|
||||
|
||||
def to_config(self) -> Tuple[str, List[Any]]:
|
||||
def to_state(self) -> Tuple[str, List[Any]]:
|
||||
"""Serialize a connector into a JSON serializable Tuple.
|
||||
|
||||
to_config is required, so that all Connectors are serializable.
|
||||
to_state is required, so that all Connectors are serializable.
|
||||
|
||||
Returns:
|
||||
A tuple of connector's name and its serialized states.
|
||||
|
@ -115,10 +115,10 @@ class Connector(abc.ABC):
|
|||
return NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
|
||||
def from_state(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
|
||||
"""De-serialize a JSON params back into a Connector.
|
||||
|
||||
from_config is required, so that all Connectors are serializable.
|
||||
from_state is required, so that all Connectors are serializable.
|
||||
|
||||
Args:
|
||||
ctx: Context for constructing this connector.
|
||||
|
@ -266,11 +266,11 @@ class ActionConnector(Connector):
|
|||
to user environments.
|
||||
|
||||
An action connector transforms a single piece of policy output in
|
||||
ActionConnectorDataType format, which is basically PolicyOutputType
|
||||
plus env and agent IDs.
|
||||
ActionConnectorDataType format, which is basically PolicyOutputType plus env and
|
||||
agent IDs.
|
||||
|
||||
Any functions that operates directly on PolicyOutputType can be
|
||||
easily adpated into an ActionConnector by using register_lambda_action_connector.
|
||||
Any functions that operate directly on PolicyOutputType can be easily adapted
|
||||
into an ActionConnector by using register_lambda_action_connector.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
@ -432,4 +432,4 @@ def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Conne
|
|||
if not _global_registry.contains(RLLIB_CONNECTOR, name):
|
||||
raise NameError("connector not found.", name)
|
||||
cls = _global_registry.get(RLLIB_CONNECTOR, name)
|
||||
return cls.from_config(ctx, params)
|
||||
return cls.from_state(ctx, params)
|
||||
|
|
|
@ -20,7 +20,7 @@ class TestActionConnector(unittest.TestCase):
|
|||
ctx = ConnectorContext()
|
||||
connectors = [ConvertToNumpyConnector(ctx)]
|
||||
pipeline = ActionConnectorPipeline(ctx, connectors)
|
||||
name, params = pipeline.to_config()
|
||||
name, params = pipeline.to_state()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
|
||||
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
|
||||
|
@ -29,7 +29,7 @@ class TestActionConnector(unittest.TestCase):
|
|||
ctx = ConnectorContext()
|
||||
c = ConvertToNumpyConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
|
||||
self.assertEqual(name, "ConvertToNumpyConnector")
|
||||
|
||||
|
@ -50,7 +50,7 @@ class TestActionConnector(unittest.TestCase):
|
|||
)
|
||||
c = NormalizeActionsConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
self.assertEqual(name, "NormalizeActionsConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
|
@ -67,7 +67,7 @@ class TestActionConnector(unittest.TestCase):
|
|||
)
|
||||
c = ClipActionsConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
self.assertEqual(name, "ClipActionsConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
|
@ -84,7 +84,7 @@ class TestActionConnector(unittest.TestCase):
|
|||
)
|
||||
c = ImmutableActionsConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
self.assertEqual(name, "ImmutableActionsConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
|
|
|
@ -25,7 +25,7 @@ class TestAgentConnector(unittest.TestCase):
|
|||
ctx = ConnectorContext()
|
||||
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
|
||||
pipeline = AgentConnectorPipeline(ctx, connectors)
|
||||
name, params = pipeline.to_config()
|
||||
name, params = pipeline.to_state()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
||||
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
|
||||
|
@ -42,7 +42,7 @@ class TestAgentConnector(unittest.TestCase):
|
|||
ctx = ConnectorContext(config={}, observation_space=obs_space)
|
||||
|
||||
c = ObsPreprocessorConnector(ctx)
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
|
||||
|
@ -70,7 +70,7 @@ class TestAgentConnector(unittest.TestCase):
|
|||
ctx = ConnectorContext()
|
||||
|
||||
c = ClipRewardAgentConnector(ctx, limit=2.0)
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
|
||||
self.assertEqual(name, "ClipRewardAgentConnector")
|
||||
self.assertAlmostEqual(params["limit"], 2.0)
|
||||
|
@ -95,7 +95,7 @@ class TestAgentConnector(unittest.TestCase):
|
|||
|
||||
c = FlattenDataAgentConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
name, params = c.to_state()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
|
||||
|
||||
|
@ -413,11 +413,11 @@ class TestViewRequirementConnector(unittest.TestCase):
|
|||
]
|
||||
agent_connector = AgentConnectorPipeline(ctx, connectors)
|
||||
|
||||
name, params = agent_connector.to_config()
|
||||
name, params = agent_connector.to_state()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
||||
for cidx, c in enumerate(connectors):
|
||||
check(restored.connectors[cidx].to_config(), c.to_config())
|
||||
check(restored.connectors[cidx].to_state(), c.to_state())
|
||||
|
||||
# simulate a rollout
|
||||
n_steps = 10
|
||||
|
|
|
@ -5,15 +5,15 @@ from ray.rllib.connectors.connector import Connector, ConnectorPipeline
|
|||
|
||||
class TestConnectorPipeline(unittest.TestCase):
|
||||
class Tom(Connector):
|
||||
def to_config():
|
||||
def to_state():
|
||||
return "tom"
|
||||
|
||||
class Bob(Connector):
|
||||
def to_config():
|
||||
def to_state():
|
||||
return "bob"
|
||||
|
||||
class Mary(Connector):
|
||||
def to_config():
|
||||
def to_state():
|
||||
return "mary"
|
||||
|
||||
class MockConnectorPipeline(ConnectorPipeline):
|
||||
|
|
|
@ -767,9 +767,9 @@ class Policy(metaclass=ABCMeta):
|
|||
# Checkpoint connectors state as well if enabled.
|
||||
connector_configs = {}
|
||||
if self.agent_connectors:
|
||||
connector_configs["agent"] = self.agent_connectors.to_config()
|
||||
connector_configs["agent"] = self.agent_connectors.to_state()
|
||||
if self.action_connectors:
|
||||
connector_configs["action"] = self.action_connectors.to_config()
|
||||
connector_configs["action"] = self.action_connectors.to_state()
|
||||
state["connector_configs"] = connector_configs
|
||||
return state
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue