2022-06-07 10:18:14 -07:00
|
|
|
from typing import Any, Callable, Dict, List, Type
|
|
|
|
|
|
|
|
from ray.rllib.connectors.connector import (
|
|
|
|
ActionConnector,
|
2022-06-29 23:44:10 -07:00
|
|
|
ConnectorContext,
|
2022-06-07 10:18:14 -07:00
|
|
|
register_connector,
|
|
|
|
)
|
|
|
|
from ray.rllib.utils.numpy import convert_to_numpy
|
|
|
|
from ray.rllib.utils.typing import (
|
|
|
|
ActionConnectorDataType,
|
|
|
|
PolicyOutputType,
|
|
|
|
StateBatches,
|
|
|
|
TensorStructType,
|
|
|
|
)
|
2022-06-29 23:44:10 -07:00
|
|
|
from ray.util.annotations import PublicAPI
|
2022-06-07 10:18:14 -07:00
|
|
|
|
|
|
|
|
2022-06-29 23:44:10 -07:00
|
|
|
@PublicAPI(stability="alpha")
def register_lambda_action_connector(
    name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
) -> Type[ActionConnector]:
    """A util to register any function transforming PolicyOutputType as an ActionConnector.

    The only requirement is that fn should take actions, states, and fetches as input,
    and return transformed actions, states, and fetches.

    Args:
        name: Name of the resulting action connector. It becomes the generated
            class's ``__name__``/``__qualname__`` and is the name passed to
            ``register_connector``.
        fn: The function that transforms PolicyOutputType.

    Returns:
        A new ActionConnector class that transforms PolicyOutputType using fn.
    """

    class LambdaActionConnector(ActionConnector):
        def transform(
            self, ac_data: ActionConnectorDataType
        ) -> ActionConnectorDataType:
            # The policy output must be an (actions, states, fetches) tuple so
            # it can be unpacked and handed to fn below.
            assert isinstance(
                ac_data.output, tuple
            ), "Action connector requires PolicyOutputType data."

            actions, states, fetches = ac_data.output
            # Keep the same env/agent ids; only the policy output is replaced
            # by fn's transformed (actions, states, fetches).
            return ActionConnectorDataType(
                ac_data.env_id,
                ac_data.agent_id,
                fn(actions, states, fetches),
            )

        def to_config(self):
            # The connector has no parameters of its own; the registered name
            # alone identifies it, so params are serialized as None.
            return name, None

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # params is ignored: to_config never emits any.
            return LambdaActionConnector(ctx)

    # Rename the generated class to the registered name so that each lambda
    # connector is distinguishable (e.g. in reprs) rather than all appearing
    # as "LambdaActionConnector".
    LambdaActionConnector.__name__ = name
    LambdaActionConnector.__qualname__ = name

    register_connector(name, LambdaActionConnector)

    return LambdaActionConnector
|
|
|
|
|
|
|
|
|
|
|
|
# Convert actions and states into numpy arrays if necessary; fetches are
# forwarded untouched.
def _actions_and_states_to_numpy(actions, states, fetches):
    """Return (actions, states, fetches) with actions and states passed through convert_to_numpy."""
    numpy_actions = convert_to_numpy(actions)
    numpy_states = convert_to_numpy(states)
    return numpy_actions, numpy_states, fetches


ConvertToNumpyConnector = PublicAPI(stability="alpha")(
    register_lambda_action_connector(
        "ConvertToNumpyConnector",
        _actions_and_states_to_numpy,
    )
)
|