mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
from typing import Any, List
|
|
|
|
import tree # pip install dm_tree
|
|
|
|
from ray.rllib.connectors.connector import (
|
|
ActionConnector,
|
|
ConnectorContext,
|
|
register_connector,
|
|
)
|
|
from ray.rllib.utils.numpy import make_action_immutable
|
|
from ray.rllib.utils.typing import ActionConnectorDataType
|
|
from ray.util.annotations import PublicAPI
|
|
|
|
|
|
@PublicAPI(stability="alpha")
|
|
class ImmutableActionsConnector(ActionConnector):
|
|
def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
|
assert isinstance(
|
|
ac_data.output, tuple
|
|
), "Action connector requires PolicyOutputType data."
|
|
|
|
actions, states, fetches = ac_data.output
|
|
tree.traverse(make_action_immutable, actions, top_down=False)
|
|
|
|
return ActionConnectorDataType(
|
|
ac_data.env_id,
|
|
ac_data.agent_id,
|
|
(actions, states, fetches),
|
|
)
|
|
|
|
def to_config(self):
|
|
return ImmutableActionsConnector.__name__, None
|
|
|
|
@staticmethod
|
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
|
return ImmutableActionsConnector(ctx)
|
|
|
|
|
|
register_connector(ImmutableActionsConnector.__name__, ImmutableActionsConnector)
|