Mirror of https://github.com/vale981/ray (synced 2025-03-09 04:46:38 -04:00).
58 lines · 1.7 KiB · Python
import gym
|
||
|
from typing import Any, List
|
||
|
|
||
|
from ray.rllib.connectors.connector import (
|
||
|
ActionConnector,
|
||
|
Connector,
|
||
|
ConnectorContext,
|
||
|
ConnectorPipeline,
|
||
|
get_connector,
|
||
|
register_connector,
|
||
|
)
|
||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||
|
from ray.rllib.utils.typing import (
|
||
|
ActionConnectorDataType,
|
||
|
TrainerConfigDict,
|
||
|
)
|
||
|
|
||
|
|
||
|
@DeveloperAPI
class ActionConnectorPipeline(ActionConnector, ConnectorPipeline):
    """An action connector that chains a list of connectors.

    Calling the pipeline feeds the input through each connector in order;
    each connector receives the previous connector's output, and the output
    of the last connector is returned. With no connectors the input is
    returned unchanged.
    """

    def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
        """Create the pipeline.

        Args:
            ctx: Connector context, forwarded to the base-class constructor.
            connectors: Connectors to apply, in call order.
        """
        super().__init__(ctx)
        self.connectors = connectors

    def is_training(self, is_training: bool):
        """Set the training flag on this pipeline and on every child connector.

        Args:
            is_training: New value of the training flag.
        """
        # NOTE(review): assigning ``self.is_training`` stores a bool on the
        # instance, which shadows this method (instance attributes win over
        # plain functions on the class). A second ``pipeline.is_training(...)``
        # call therefore raises TypeError ('bool' object is not callable), and
        # the same applies to child connectors that follow this pattern.
        # Fixing it needs a coordinated rename (e.g. a private flag) in the
        # Connector base class, which is not visible here.
        self.is_training = is_training
        for c in self.connectors:
            c.is_training(is_training)

    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Run ``ac_data`` through every connector in order.

        Args:
            ac_data: Input passed to the first connector.

        Returns:
            The output of the last connector, or ``ac_data`` itself if the
            pipeline has no connectors.
        """
        for c in self.connectors:
            ac_data = c(ac_data)
        return ac_data

    def to_config(self):
        """Serialize the pipeline.

        Returns:
            A ``(name, params)`` tuple, where ``name`` is this class's name and
            ``params`` is the list of each child connector's ``to_config()``
            result, in pipeline order.
        """
        return ActionConnectorPipeline.__name__, [
            c.to_config() for c in self.connectors
        ]

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Rebuild a pipeline from the ``params`` part of ``to_config()``.

        Args:
            ctx: Connector context passed to every child connector and to
                the new pipeline.
            params: List of ``(name, subparams)`` pairs, one per connector.

        Returns:
            A new ``ActionConnectorPipeline`` with the connectors in order.

        Raises:
            AssertionError: If ``params`` is not a list.
        """
        # isinstance is the idiomatic type check and also accepts list
        # subclasses; the assertion message is kept unchanged.
        assert isinstance(
            params, list
        ), "ActionConnectorPipeline takes a list of connector params."
        connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
        return ActionConnectorPipeline(ctx, connectors)
|
||
|
|
||
|
|
||
|
# Register the pipeline class under its own ``__name__``, which is also the
# name that ``ActionConnectorPipeline.to_config()`` emits.
register_connector(
    ActionConnectorPipeline.__name__,
    ActionConnectorPipeline,
)
|
||
|
|
||
|
|
||
|
@DeveloperAPI
def get_action_connectors_from_trainer_config(
    config: TrainerConfigDict, action_space: gym.Space
) -> ActionConnectorPipeline:
    """Build the action connector pipeline for a trainer config.

    Args:
        config: Trainer config dict.
        action_space: Action space the connectors operate on.

    Returns:
        An ``ActionConnectorPipeline``. It currently contains no connectors,
        so calling it returns its input unchanged.
    """
    # ActionConnectorPipeline.__init__ requires a ConnectorContext as its
    # first argument; passing only the connector list raises TypeError.
    # NOTE(review): assumes ConnectorContext accepts ``config`` and
    # ``action_space`` keyword arguments — confirm against connector.py.
    ctx = ConnectorContext(config=config, action_space=action_space)
    connectors: List[Connector] = []
    return ActionConnectorPipeline(ctx, connectors)
|