# Mirror of https://github.com/vale981/ray
# Synced 2025-03-06 18:41:40 -05:00
# 82 lines, 2.3 KiB, Python
import numpy as np
|
||
|
import tree # dm_tree
|
||
|
from typing import Any, Callable, Dict, List, Type
|
||
|
|
||
|
from ray.rllib.connectors.connector import (
|
||
|
ConnectorContext,
|
||
|
AgentConnector,
|
||
|
register_connector,
|
||
|
)
|
||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||
|
from ray.rllib.utils.typing import (
|
||
|
AgentConnectorDataType,
|
||
|
TensorStructType,
|
||
|
)
|
||
|
|
||
|
|
||
|
@DeveloperAPI
def register_lambda_agent_connector(
    name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
    """Wrap a plain transform function in a registered AgentConnector class.

    The generated connector applies fn to the ``data`` field of the incoming
    AgentConnectorDataType and returns a one-element list holding a new
    AgentConnectorDataType that carries the same env_id and agent_id together
    with the transformed data. fn must accept one data object and return one
    data object.

    Args:
        name: Name given to the generated agent connector class; it is also
            the key under which the class is registered.
        fn: The function that transforms env / agent data.

    Returns:
        The newly created AgentConnector subclass that transforms data
        using fn.
    """

    class LambdaAgentConnector(AgentConnector):
        def __call__(
            self, ac_data: AgentConnectorDataType
        ) -> List[AgentConnectorDataType]:
            # Read the payload and the ids first, then apply the transform.
            payload = ac_data.data
            env_id = ac_data.env_id
            agent_id = ac_data.agent_id
            return [AgentConnectorDataType(env_id, agent_id, fn(payload))]

        def to_config(self):
            # The registered name identifies this connector; no params needed.
            return (name, None)

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # params is not used: the connector is rebuilt from ctx alone.
            return LambdaAgentConnector(ctx)

    # Expose the class under the caller-supplied name instead of the generic
    # "LambdaAgentConnector" used inside this factory.
    LambdaAgentConnector.__name__ = name
    LambdaAgentConnector.__qualname__ = name

    register_connector(name, LambdaAgentConnector)

    return LambdaAgentConnector
|
||
|
|
||
|
|
||
|
@DeveloperAPI
def flatten_data(data: Dict[str, TensorStructType]):
    """Return a new dict in which a single agent's data columns are flattened.

    Columns keyed by SampleBatch.INFOS or SampleBatch.ACTIONS, columns whose
    key starts with "state_out_", and columns whose value is None are carried
    over unchanged. Every other value is passed through tree.flatten and the
    resulting leaves are wrapped in np.array.

    Args:
        data: Mapping from column name to column value. It must be exactly a
            dict; the type check below rejects any other type.

    Returns:
        A new dict with the same keys as data.
    """
    assert (
        type(data) is dict
    ), "Single agent data must be of type Dict[str, TensorStructType]"

    untouched_columns = (SampleBatch.INFOS, SampleBatch.ACTIONS)

    def _convert(column, value):
        # Do not flatten infos, actions, and state_out_ columns.
        if column in untouched_columns or column.startswith("state_out_"):
            return value
        # Keep empty columns as None so the set of columns stays the same.
        if value is None:
            return None
        return np.array(tree.flatten(value))

    return {column: _convert(column, value) for column, value in data.items()}
|
||
|
|
||
|
|
||
|
# Agent connector that runs flatten_data on the data it receives; it is
# registered under the name "FlattenDataAgentConnector".
FlattenDataAgentConnector = register_lambda_agent_connector(
    name="FlattenDataAgentConnector",
    fn=flatten_data,
)
|