ray/rllib/algorithms/appo/utils.py

45 lines
1.3 KiB
Python

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
# Name passed to `ModelCatalog.get_model_v2` when building the main model.
POLICY_SCOPE = "func"
# Name passed to `ModelCatalog.get_model_v2` when building the target model.
TARGET_POLICY_SCOPE = "target_func"
def make_appo_models(policy) -> ModelV2:
    """Create the main model and the target model of an APPO policy.

    Both models are built through `ModelCatalog.get_model_v2` from the same
    observation space, action space, output size, model config and
    framework; they differ only in the name they are given. Each model and
    the result of its `variables()` call are stored on `policy`.

    Args:
        policy: The policy to build the models for. Its `observation_space`,
            `action_space`, `config["model"]` and `framework` are read, and
            its `model`, `model_variables`, `target_model` and
            `target_model_variables` attributes are set.

    Returns:
        ModelV2: The main model (`policy.model`). The target model is not
        returned; it is only reachable through `policy.target_model`.
    """
    model_config = policy.config["model"]

    # The action distribution dictates how many outputs the models need.
    _, num_outputs = ModelCatalog.get_action_dist(
        policy.action_space, model_config
    )

    def _build(scope_name):
        # Main and target model share every construction argument but the
        # name.
        return ModelCatalog.get_model_v2(
            policy.observation_space,
            policy.action_space,
            num_outputs,
            model_config,
            name=scope_name,
            framework=policy.framework,
        )

    # Main model first, then its variables.
    policy.model = _build(POLICY_SCOPE)
    policy.model_variables = policy.model.variables()

    # Target model second, built the same way under its own name.
    policy.target_model = _build(TARGET_POLICY_SCOPE)
    policy.target_model_variables = policy.target_model.variables()

    # Callers only get the main model back.
    return policy.model