ray/rllib/connectors/agent/lambdas.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

82 lines
2.3 KiB
Python
Raw Normal View History

import numpy as np
import tree # dm_tree
from typing import Any, Callable, Dict, List, Type
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import (
AgentConnectorDataType,
TensorStructType,
)
@DeveloperAPI
def register_lambda_agent_connector(
    name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
    """Wrap a plain transform function into a registered AgentConnector class.

    ``fn`` must accept one data object and return one data object; it is
    applied to the ``data`` field of every item passed through the connector.

    Args:
        name: Name of the resulting agent connector. Used as the class name
            and as the key under which the class is registered.
        fn: The function that transforms env / agent data.

    Returns:
        A new AgentConnector class that transforms data using fn.
    """

    class LambdaAgentConnector(AgentConnector):
        def __call__(
            self, ac_data: AgentConnectorDataType
        ) -> List[AgentConnectorDataType]:
            # Only the payload is transformed; env and agent ids are
            # carried over unchanged.
            payload = ac_data.data
            transformed = AgentConnectorDataType(
                ac_data.env_id, ac_data.agent_id, fn(payload)
            )
            return [transformed]

        def to_config(self):
            # No parameters to serialize: the connector name is all that is
            # stored in the config.
            return (name, None)

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # ``params`` is ignored (to_config() stores None for it).
            return LambdaAgentConnector(ctx)

    # Make the generated class report the requested name instead of the
    # generic local class name.
    for attr in ("__name__", "__qualname__"):
        setattr(LambdaAgentConnector, attr, name)

    register_connector(name, LambdaAgentConnector)

    return LambdaAgentConnector
@DeveloperAPI
def flatten_data(data: Dict[str, TensorStructType]):
    """Flatten the nested columns of a single agent's data dict.

    Each column value is passed through ``tree.flatten`` and wrapped in a
    NumPy array, except for:

    * the ``SampleBatch.INFOS`` and ``SampleBatch.ACTIONS`` columns and any
      column whose key starts with ``"state_out_"``, which are copied over
      untouched;
    * columns whose value is ``None``, which stay ``None``.

    Args:
        data: Mapping from column name to (possibly nested) column data for
            one agent. Must be a ``dict`` (or a subclass of it).

    Returns:
        A new plain ``dict`` with the same keys as ``data`` and the
        flattened values. The input dict itself is not modified.

    Raises:
        AssertionError: If ``data`` is not a ``dict`` instance.
    """
    # isinstance() instead of an exact ``type(data) == dict`` comparison so
    # that dict subclasses (e.g. OrderedDict) are accepted as well.
    assert isinstance(
        data, dict
    ), "Single agent data must be of type Dict[str, TensorStructType]"

    # Columns that are never flattened; built once instead of per iteration.
    passthrough_keys = (SampleBatch.INFOS, SampleBatch.ACTIONS)

    flattened = {}
    for k, v in data.items():
        if k in passthrough_keys or k.startswith("state_out_"):
            # Do not flatten infos, actions, and state_out_ columns.
            flattened[k] = v
        elif v is None:
            # Keep the column present so the set of keys stays the same.
            flattened[k] = None
        else:
            flattened[k] = np.array(tree.flatten(v))

    return flattened
# Agent connector class that runs flatten_data() on the data of every item
# it receives; registered under the name "FlattenDataAgentConnector".
FlattenDataAgentConnector = register_lambda_agent_connector(
    name="FlattenDataAgentConnector",
    fn=flatten_data,
)