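"""Utilities for building agent and action connector pipelines for RLlib policies."""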
import logging
from typing import Any, Tuple, TYPE_CHECKING

from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import Connector, ConnectorContext, get_connector
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
def get_agent_connectors_from_config(
    ctx: ConnectorContext,
    config: TrainerConfigDict,
) -> AgentConnectorPipeline:
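    """Default list of agent connectors to use for a new policy.

    Args:
        ctx: context used to create connectors.
        config: trainer config.
    """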
    connectors = []
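    # Clip rewards to their sign (+1.0 / -1.0) if `clip_rewards` is True,
    # or to +/- the given limit if it is a float.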
if config["clip_rewards"] is True:
|
|
connectors.append(ClipRewardAgentConnector(ctx, sign=True))
|
|
elif type(config["clip_rewards"]) == float:
|
|
connectors.append(
|
|
ClipRewardAgentConnector(ctx, limit=abs(config["clip_rewards"]))
|
|
)
|
|
|
|
if not config["_disable_preprocessor_api"]:
|
|
connectors.append(ObsPreprocessorConnector(ctx))
|
|
|
|
connectors.extend(
|
|
[
|
|
StateBufferConnector(ctx),
|
|
ViewRequirementAgentConnector(ctx),
|
|
FlattenDataAgentConnector(ctx), # Creates batch dimension.
|
|
]
|
|
)
|
|
|
|
return AgentConnectorPipeline(ctx, connectors)


@PublicAPI(stability="alpha")
def get_action_connectors_from_config(
    ctx: ConnectorContext,
    config: TrainerConfigDict,
) -> ActionConnectorPipeline:
    """Default list of action connectors to use for a new policy.

    Args:
        ctx: context used to create connectors.
        config: trainer config.
    """
    connectors = [ConvertToNumpyConnector(ctx)]
    if config.get("normalize_actions", False):
        connectors.append(NormalizeActionsConnector(ctx))
    if config.get("clip_actions", False):
        connectors.append(ClipActionsConnector(ctx))
    connectors.append(ImmutableActionsConnector(ctx))
    return ActionConnectorPipeline(ctx, connectors)


@PublicAPI(stability="alpha")
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    """Util to create agent and action connectors for a Policy.

    Args:
        policy: Policy instance.
        config: Trainer config dict.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)

    logger.info("Using connectors:")
    logger.info(policy.agent_connectors.__str__(indentation=4))
    logger.info(policy.action_connectors.__str__(indentation=4))


@PublicAPI(stability="alpha")
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
    """Util to create a connector for a Policy based on a serialized config.

    Args:
        policy: Policy instance.
        connector_config: Serialized connector config.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    name, params = connector_config
    return get_connector(ctx, name, params)