ray/rllib/connectors/action/immutable.py

39 lines
1.2 KiB
Python

from typing import Any, List
import tree # pip install dm_tree
from ray.rllib.connectors.connector import (
ActionConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.utils.numpy import make_action_immutable
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
@PublicAPI(stability="alpha")
class ImmutableActionsConnector(ActionConnector):
def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
assert isinstance(
ac_data.output, tuple
), "Action connector requires PolicyOutputType data."
actions, states, fetches = ac_data.output
tree.traverse(make_action_immutable, actions, top_down=False)
return ActionConnectorDataType(
ac_data.env_id,
ac_data.agent_id,
(actions, states, fetches),
)
def to_config(self):
return ImmutableActionsConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return ImmutableActionsConnector(ctx)
register_connector(ImmutableActionsConnector.__name__, ImmutableActionsConnector)