mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Rename connector's from/to config methods to better reflect that they include state. (#27806)
This commit is contained in:
parent
328e6ac2f4
commit
2ce80d8163
16 changed files with 50 additions and 52 deletions
|
@ -29,11 +29,11 @@ class ClipActionsConnector(ActionConnector):
|
||||||
(clip_action(actions, self._action_space_struct), states, fetches),
|
(clip_action(actions, self._action_space_struct), states, fetches),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return ClipActionsConnector.__name__, None
|
return ClipActionsConnector.__name__, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return ClipActionsConnector(ctx)
|
return ClipActionsConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,11 +28,11 @@ class ImmutableActionsConnector(ActionConnector):
|
||||||
(actions, states, fetches),
|
(actions, states, fetches),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return ImmutableActionsConnector.__name__, None
|
return ImmutableActionsConnector.__name__, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return ImmutableActionsConnector(ctx)
|
return ImmutableActionsConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -47,11 +47,11 @@ def register_lambda_action_connector(
|
||||||
fn(actions, states, fetches),
|
fn(actions, states, fetches),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return name, None
|
return name, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return LambdaActionConnector(ctx)
|
return LambdaActionConnector(ctx)
|
||||||
|
|
||||||
LambdaActionConnector.__name__ = name
|
LambdaActionConnector.__name__ = name
|
||||||
|
|
|
@ -32,11 +32,11 @@ class NormalizeActionsConnector(ActionConnector):
|
||||||
(unsquash_action(actions, self._action_space_struct), states, fetches),
|
(unsquash_action(actions, self._action_space_struct), states, fetches),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return NormalizeActionsConnector.__name__, None
|
return NormalizeActionsConnector.__name__, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return NormalizeActionsConnector(ctx)
|
return NormalizeActionsConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,13 +28,11 @@ class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
|
||||||
ac_data = c(ac_data)
|
ac_data = c(ac_data)
|
||||||
return ac_data
|
return ac_data
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return ActionConnectorPipeline.__name__, [
|
return ActionConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
|
||||||
c.to_config() for c in self.connectors
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
assert (
|
assert (
|
||||||
type(params) == list
|
type(params) == list
|
||||||
), "ActionConnectorPipeline takes a list of connector params."
|
), "ActionConnectorPipeline takes a list of connector params."
|
||||||
|
|
|
@ -39,14 +39,14 @@ class ClipRewardAgentConnector(AgentConnector):
|
||||||
)
|
)
|
||||||
return ac_data
|
return ac_data
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return ClipRewardAgentConnector.__name__, {
|
return ClipRewardAgentConnector.__name__, {
|
||||||
"sign": self.sign,
|
"sign": self.sign,
|
||||||
"limit": self.limit,
|
"limit": self.limit,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return ClipRewardAgentConnector(ctx, **params)
|
return ClipRewardAgentConnector(ctx, **params)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -39,11 +39,11 @@ def register_lambda_agent_connector(
|
||||||
ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
|
ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return name, None
|
return name, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return LambdaAgentConnector(ctx)
|
return LambdaAgentConnector(ctx)
|
||||||
|
|
||||||
LambdaAgentConnector.__name__ = name
|
LambdaAgentConnector.__name__ = name
|
||||||
|
|
|
@ -56,11 +56,11 @@ class ObsPreprocessorConnector(AgentConnector):
|
||||||
|
|
||||||
return ac_data
|
return ac_data
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return ObsPreprocessorConnector.__name__, {}
|
return ObsPreprocessorConnector.__name__, {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return ObsPreprocessorConnector(ctx, **params)
|
return ObsPreprocessorConnector(ctx, **params)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -39,11 +39,11 @@ class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
|
||||||
ret = c(ret)
|
ret = c(ret)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors]
|
return AgentConnectorPipeline.__name__, [c.to_state() for c in self.connectors]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
assert (
|
assert (
|
||||||
type(params) == list
|
type(params) == list
|
||||||
), "AgentConnectorPipeline takes a list of connector params."
|
), "AgentConnectorPipeline takes a list of connector params."
|
||||||
|
|
|
@ -71,11 +71,11 @@ class StateBufferConnector(AgentConnector):
|
||||||
|
|
||||||
return ac_data
|
return ac_data
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return StateBufferConnector.__name__, None
|
return StateBufferConnector.__name__, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return StateBufferConnector(ctx)
|
return StateBufferConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -119,11 +119,11 @@ class ViewRequirementAgentConnector(AgentConnector):
|
||||||
)
|
)
|
||||||
return return_data
|
return return_data
|
||||||
|
|
||||||
def to_config(self):
|
def to_state(self):
|
||||||
return ViewRequirementAgentConnector.__name__, None
|
return ViewRequirementAgentConnector.__name__, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
def from_state(ctx: ConnectorContext, params: List[Any]):
|
||||||
return ViewRequirementAgentConnector(ctx)
|
return ViewRequirementAgentConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ class Connector(abc.ABC):
|
||||||
Connectors may be training-aware, for example, behave slightly differently
|
Connectors may be training-aware, for example, behave slightly differently
|
||||||
during training and inference.
|
during training and inference.
|
||||||
|
|
||||||
All connectors are required to be serializable and implement to_config().
|
All connectors are required to be serializable and implement to_state().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ctx: ConnectorContext):
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
@ -103,10 +103,10 @@ class Connector(abc.ABC):
|
||||||
def __str__(self, indentation: int = 0):
|
def __str__(self, indentation: int = 0):
|
||||||
return " " * indentation + self.__class__.__name__
|
return " " * indentation + self.__class__.__name__
|
||||||
|
|
||||||
def to_config(self) -> Tuple[str, List[Any]]:
|
def to_state(self) -> Tuple[str, List[Any]]:
|
||||||
"""Serialize a connector into a JSON serializable Tuple.
|
"""Serialize a connector into a JSON serializable Tuple.
|
||||||
|
|
||||||
to_config is required, so that all Connectors are serializable.
|
to_state is required, so that all Connectors are serializable.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of the connector's name and its serialized state.
|
A tuple of the connector's name and its serialized state.
|
||||||
|
@ -115,10 +115,10 @@ class Connector(abc.ABC):
|
||||||
return NotImplementedError
|
return NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
|
def from_state(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
|
||||||
"""De-serialize a JSON params back into a Connector.
|
"""De-serialize a JSON params back into a Connector.
|
||||||
|
|
||||||
from_config is required, so that all Connectors are serializable.
|
from_state is required, so that all Connectors are serializable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: Context for constructing this connector.
|
ctx: Context for constructing this connector.
|
||||||
|
@ -266,11 +266,11 @@ class ActionConnector(Connector):
|
||||||
to user environments.
|
to user environments.
|
||||||
|
|
||||||
An action connector transforms a single piece of policy output in
|
An action connector transforms a single piece of policy output in
|
||||||
ActionConnectorDataType format, which is basically PolicyOutputType
|
ActionConnectorDataType format, which is basically PolicyOutputType plus env and
|
||||||
plus env and agent IDs.
|
agent IDs.
|
||||||
|
|
||||||
Any functions that operates directly on PolicyOutputType can be
|
Any functions that operate directly on PolicyOutputType can be easily adapted
|
||||||
easily adpated into an ActionConnector by using register_lambda_action_connector.
|
into an ActionConnector by using register_lambda_action_connector.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -432,4 +432,4 @@ def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Conne
|
||||||
if not _global_registry.contains(RLLIB_CONNECTOR, name):
|
if not _global_registry.contains(RLLIB_CONNECTOR, name):
|
||||||
raise NameError("connector not found.", name)
|
raise NameError("connector not found.", name)
|
||||||
cls = _global_registry.get(RLLIB_CONNECTOR, name)
|
cls = _global_registry.get(RLLIB_CONNECTOR, name)
|
||||||
return cls.from_config(ctx, params)
|
return cls.from_state(ctx, params)
|
||||||
|
|
|
@ -20,7 +20,7 @@ class TestActionConnector(unittest.TestCase):
|
||||||
ctx = ConnectorContext()
|
ctx = ConnectorContext()
|
||||||
connectors = [ConvertToNumpyConnector(ctx)]
|
connectors = [ConvertToNumpyConnector(ctx)]
|
||||||
pipeline = ActionConnectorPipeline(ctx, connectors)
|
pipeline = ActionConnectorPipeline(ctx, connectors)
|
||||||
name, params = pipeline.to_config()
|
name, params = pipeline.to_state()
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
|
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
|
||||||
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
|
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
|
||||||
|
@ -29,7 +29,7 @@ class TestActionConnector(unittest.TestCase):
|
||||||
ctx = ConnectorContext()
|
ctx = ConnectorContext()
|
||||||
c = ConvertToNumpyConnector(ctx)
|
c = ConvertToNumpyConnector(ctx)
|
||||||
|
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
|
|
||||||
self.assertEqual(name, "ConvertToNumpyConnector")
|
self.assertEqual(name, "ConvertToNumpyConnector")
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class TestActionConnector(unittest.TestCase):
|
||||||
)
|
)
|
||||||
c = NormalizeActionsConnector(ctx)
|
c = NormalizeActionsConnector(ctx)
|
||||||
|
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
self.assertEqual(name, "NormalizeActionsConnector")
|
self.assertEqual(name, "NormalizeActionsConnector")
|
||||||
|
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
|
@ -67,7 +67,7 @@ class TestActionConnector(unittest.TestCase):
|
||||||
)
|
)
|
||||||
c = ClipActionsConnector(ctx)
|
c = ClipActionsConnector(ctx)
|
||||||
|
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
self.assertEqual(name, "ClipActionsConnector")
|
self.assertEqual(name, "ClipActionsConnector")
|
||||||
|
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
|
@ -84,7 +84,7 @@ class TestActionConnector(unittest.TestCase):
|
||||||
)
|
)
|
||||||
c = ImmutableActionsConnector(ctx)
|
c = ImmutableActionsConnector(ctx)
|
||||||
|
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
self.assertEqual(name, "ImmutableActionsConnector")
|
self.assertEqual(name, "ImmutableActionsConnector")
|
||||||
|
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
|
|
|
@ -25,7 +25,7 @@ class TestAgentConnector(unittest.TestCase):
|
||||||
ctx = ConnectorContext()
|
ctx = ConnectorContext()
|
||||||
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
|
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
|
||||||
pipeline = AgentConnectorPipeline(ctx, connectors)
|
pipeline = AgentConnectorPipeline(ctx, connectors)
|
||||||
name, params = pipeline.to_config()
|
name, params = pipeline.to_state()
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
||||||
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
|
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
|
||||||
|
@ -42,7 +42,7 @@ class TestAgentConnector(unittest.TestCase):
|
||||||
ctx = ConnectorContext(config={}, observation_space=obs_space)
|
ctx = ConnectorContext(config={}, observation_space=obs_space)
|
||||||
|
|
||||||
c = ObsPreprocessorConnector(ctx)
|
c = ObsPreprocessorConnector(ctx)
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
|
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
|
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
|
||||||
|
@ -70,7 +70,7 @@ class TestAgentConnector(unittest.TestCase):
|
||||||
ctx = ConnectorContext()
|
ctx = ConnectorContext()
|
||||||
|
|
||||||
c = ClipRewardAgentConnector(ctx, limit=2.0)
|
c = ClipRewardAgentConnector(ctx, limit=2.0)
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
|
|
||||||
self.assertEqual(name, "ClipRewardAgentConnector")
|
self.assertEqual(name, "ClipRewardAgentConnector")
|
||||||
self.assertAlmostEqual(params["limit"], 2.0)
|
self.assertAlmostEqual(params["limit"], 2.0)
|
||||||
|
@ -95,7 +95,7 @@ class TestAgentConnector(unittest.TestCase):
|
||||||
|
|
||||||
c = FlattenDataAgentConnector(ctx)
|
c = FlattenDataAgentConnector(ctx)
|
||||||
|
|
||||||
name, params = c.to_config()
|
name, params = c.to_state()
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
|
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
|
||||||
|
|
||||||
|
@ -413,11 +413,11 @@ class TestViewRequirementConnector(unittest.TestCase):
|
||||||
]
|
]
|
||||||
agent_connector = AgentConnectorPipeline(ctx, connectors)
|
agent_connector = AgentConnectorPipeline(ctx, connectors)
|
||||||
|
|
||||||
name, params = agent_connector.to_config()
|
name, params = agent_connector.to_state()
|
||||||
restored = get_connector(ctx, name, params)
|
restored = get_connector(ctx, name, params)
|
||||||
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
||||||
for cidx, c in enumerate(connectors):
|
for cidx, c in enumerate(connectors):
|
||||||
check(restored.connectors[cidx].to_config(), c.to_config())
|
check(restored.connectors[cidx].to_state(), c.to_state())
|
||||||
|
|
||||||
# simulate a rollout
|
# simulate a rollout
|
||||||
n_steps = 10
|
n_steps = 10
|
||||||
|
|
|
@ -5,15 +5,15 @@ from ray.rllib.connectors.connector import Connector, ConnectorPipeline
|
||||||
|
|
||||||
class TestConnectorPipeline(unittest.TestCase):
|
class TestConnectorPipeline(unittest.TestCase):
|
||||||
class Tom(Connector):
|
class Tom(Connector):
|
||||||
def to_config():
|
def to_state():
|
||||||
return "tom"
|
return "tom"
|
||||||
|
|
||||||
class Bob(Connector):
|
class Bob(Connector):
|
||||||
def to_config():
|
def to_state():
|
||||||
return "bob"
|
return "bob"
|
||||||
|
|
||||||
class Mary(Connector):
|
class Mary(Connector):
|
||||||
def to_config():
|
def to_state():
|
||||||
return "mary"
|
return "mary"
|
||||||
|
|
||||||
class MockConnectorPipeline(ConnectorPipeline):
|
class MockConnectorPipeline(ConnectorPipeline):
|
||||||
|
|
|
@ -767,9 +767,9 @@ class Policy(metaclass=ABCMeta):
|
||||||
# Checkpoint connectors state as well if enabled.
|
# Checkpoint connectors state as well if enabled.
|
||||||
connector_configs = {}
|
connector_configs = {}
|
||||||
if self.agent_connectors:
|
if self.agent_connectors:
|
||||||
connector_configs["agent"] = self.agent_connectors.to_config()
|
connector_configs["agent"] = self.agent_connectors.to_state()
|
||||||
if self.action_connectors:
|
if self.action_connectors:
|
||||||
connector_configs["action"] = self.action_connectors.to_config()
|
connector_configs["action"] = self.action_connectors.to_state()
|
||||||
state["connector_configs"] = connector_configs
|
state["connector_configs"] = connector_configs
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue