from typing import Any, List from ray.rllib.connectors.connector import ( ConnectorContext, ActionConnector, register_connector, ) from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.spaces.space_utils import ( get_base_struct_from_space, unsquash_action, ) from ray.rllib.utils.typing import ActionConnectorDataType @DeveloperAPI class NormalizeActionsConnector(ActionConnector): def __init__(self, ctx: ConnectorContext): super().__init__(ctx) self._action_space_struct = get_base_struct_from_space(ctx.action_space) def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: assert isinstance( ac_data.output, tuple ), "Action connector requires PolicyOutputType data." actions, states, fetches = ac_data.output return ActionConnectorDataType( ac_data.env_id, ac_data.agent_id, (unsquash_action(actions, self._action_space_struct), states, fetches), ) def to_config(self): return NormalizeActionsConnector.__name__, None @staticmethod def from_config(ctx: ConnectorContext, params: List[Any]): return NormalizeActionsConnector(ctx) register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)