2022-07-09 01:06:24 -07:00
|
|
|
from typing import Any, Callable, List, Type
|
2022-06-29 23:44:10 -07:00
|
|
|
|
2022-06-07 10:18:14 -07:00
|
|
|
import numpy as np
|
|
|
|
import tree # dm_tree
|
|
|
|
|
|
|
|
from ray.rllib.connectors.connector import (
|
|
|
|
AgentConnector,
|
2022-06-29 23:44:10 -07:00
|
|
|
ConnectorContext,
|
2022-06-07 10:18:14 -07:00
|
|
|
register_connector,
|
|
|
|
)
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
|
|
from ray.rllib.utils.typing import (
|
|
|
|
AgentConnectorDataType,
|
2022-06-29 23:44:10 -07:00
|
|
|
AgentConnectorsOutput,
|
2022-06-07 10:18:14 -07:00
|
|
|
)
|
2022-06-29 23:44:10 -07:00
|
|
|
from ray.util.annotations import PublicAPI
|
2022-06-07 10:18:14 -07:00
|
|
|
|
|
|
|
|
2022-06-29 23:44:10 -07:00
|
|
|
@PublicAPI(stability="alpha")
def register_lambda_agent_connector(
    name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
    """Wrap a simple transforming function into a registered AgentConnector.

    fn only has to accept a single data object and hand back a single
    data object.

    Args:
        name: Name of the resulting agent connector.
        fn: The function that transforms env / agent data.

    Returns:
        A new AgentConnector class that transforms data using fn.
    """

    class LambdaAgentConnector(AgentConnector):
        def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
            # Only the payload goes through fn; the env and agent ids are
            # carried over to the new data object unchanged.
            env_id, agent_id = ac_data.env_id, ac_data.agent_id
            return AgentConnectorDataType(env_id, agent_id, fn(ac_data.data))

        def to_config(self):
            # The registered name is the only state; there are no params.
            return (name, None)

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # params is not used: the connector is rebuilt from ctx alone.
            return LambdaAgentConnector(ctx)

    # Expose the generated class under the caller-supplied name rather than
    # the generic local class name.
    for attr in ("__name__", "__qualname__"):
        setattr(LambdaAgentConnector, attr, name)

    register_connector(name, LambdaAgentConnector)

    return LambdaAgentConnector
|
|
|
|
|
|
|
|
|
2022-06-29 23:44:10 -07:00
|
|
|
@PublicAPI(stability="alpha")
def flatten_data(data: AgentConnectorsOutput):
    """Flatten the ``for_action`` columns of a single agent's connector output.

    Each column of ``data.for_action`` is passed through ``tree.flatten`` and
    wrapped in a numpy array, with two exceptions: the infos, actions and
    ``state_out_*`` columns are copied over as they are, and a column whose
    value is None stays None. ``data.for_training`` is forwarded unchanged.

    Args:
        data: The connector output of a single agent.

    Returns:
        A new AgentConnectorsOutput made of the original ``for_training``
        value and a SampleBatch (built with ``is_training=False``) holding
        the flattened columns.
    """
    assert isinstance(
        data, AgentConnectorsOutput
    ), "Single agent data must be of type AgentConnectorsOutput"

    training_part = data.for_training
    action_part = data.for_action

    unflattened_keys = (SampleBatch.INFOS, SampleBatch.ACTIONS)

    def _keep_raw(key, value):
        # Infos, actions, and state_out_ columns are never flattened.
        if key in unflattened_keys or key.startswith("state_out_"):
            return True
        # A None column is kept so the set of columns stays the same.
        return value is None

    columns = {
        key: value if _keep_raw(key, value) else np.array(tree.flatten(value))
        for key, value in action_part.items()
    }

    return AgentConnectorsOutput(
        training_part, SampleBatch(columns, is_training=False)
    )
|
2022-06-07 10:18:14 -07:00
|
|
|
|
|
|
|
|
2022-06-29 23:44:10 -07:00
|
|
|
# Agent connector that runs flatten_data on the incoming data: it is
# registered under the name "FlattenDataAgentConnector" and then marked as
# an alpha-stability public API.
FlattenDataAgentConnector = register_lambda_agent_connector(
    "FlattenDataAgentConnector", flatten_data
)
FlattenDataAgentConnector = PublicAPI(stability="alpha")(FlattenDataAgentConnector)
|