ray/rllib/connectors/action/normalize.py

43 lines
1.3 KiB
Python

from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
ActionConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.spaces.space_utils import (
get_base_struct_from_space,
unsquash_action,
)
from ray.rllib.utils.typing import ActionConnectorDataType
@DeveloperAPI
class NormalizeActionsConnector(ActionConnector):
    """Action connector that passes policy actions through ``unsquash_action``.

    The actions in the policy output are handed to ``unsquash_action`` together
    with a base struct derived from ``ctx.action_space``; the states and fetches
    of the policy output are forwarded unchanged. Presumably this rescales
    normalized actions back into the bounds of the action space — TODO confirm
    against ``ray.rllib.utils.spaces.space_utils.unsquash_action``.

    The connector has no tunable parameters: ``to_config`` serializes it as
    ``(class name, None)`` and ``from_config`` rebuilds it from the context
    alone.
    """

    def __init__(self, ctx: ConnectorContext):
        """Create the connector for the action space described by ``ctx``.

        Args:
            ctx: Connector context; this class reads ``ctx.action_space``
                (the base class may read more).
        """
        super().__init__(ctx)
        # The struct depends only on the action space, so derive it once here
        # instead of on every call.
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)

    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Unsquash the actions contained in ``ac_data.output``.

        Args:
            ac_data: Connector data whose ``output`` is expected to be an
                ``(actions, states, fetches)`` tuple.

        Returns:
            A new ``ActionConnectorDataType`` carrying the same ``env_id`` and
            ``agent_id``, whose output tuple holds the actions passed through
            ``unsquash_action`` plus the original ``states`` and ``fetches``.

        Raises:
            AssertionError: If ``ac_data.output`` is not a tuple. The check is
                an explicit ``raise`` rather than an ``assert`` statement so it
                still runs under ``python -O``.
            ValueError: If the output tuple does not have exactly three
                elements (from the tuple unpacking below).
        """
        output = ac_data.output
        if not isinstance(output, tuple):
            raise AssertionError("Action connector requires PolicyOutputType data.")

        actions, states, fetches = output
        return ActionConnectorDataType(
            ac_data.env_id,
            ac_data.agent_id,
            (unsquash_action(actions, self._action_space_struct), states, fetches),
        )

    def to_config(self):
        """Serialize this connector as a ``(name, params)`` pair.

        Returns:
            The class name and ``None``; the connector carries no parameters
            beyond what ``from_config`` obtains from the context.
        """
        return NormalizeActionsConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: Any):
        """Rebuild the connector from a context.

        Args:
            ctx: Connector context providing the action space.
            params: Unused; ``to_config`` always emits ``None`` here.

        Returns:
            A fresh ``NormalizeActionsConnector`` built from ``ctx``.
        """
        return NormalizeActionsConnector(ctx)
# Register this connector under its class name at import time. Presumably the
# registry resolves the name emitted by `to_config` back to this class so
# `from_config` can rebuild it — confirm against the connector registry.
register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)