from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2

POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"


def make_appo_models(policy) -> ModelV2:
    """Create the APPO policy model and its target-network twin.

    Both networks are built through ``ModelCatalog.get_model_v2`` with the
    same spaces, output size and ``policy.config["model"]`` settings. They
    differ only in the ``name`` passed (``POLICY_SCOPE`` vs.
    ``TARGET_POLICY_SCOPE``).

    Side effects on ``policy``:
        model, model_variables: the main network and the result of its
            ``variables()`` call.
        target_model, target_model_variables: the target network and the
            result of its ``variables()`` call.

    Args:
        policy: The policy being set up. Attributes read here:
            ``observation_space``, ``action_space``, ``config["model"]`` and
            ``framework``.

    Returns:
        ModelV2: The main model (the same object stored in ``policy.model``).
        The target model is not returned; it is only reachable via
        ``policy.target_model``.
    """
    model_config = policy.config["model"]

    # Only the output size is needed here; the distribution class is discarded.
    _, num_outputs = ModelCatalog.get_action_dist(policy.action_space, model_config)

    def _build(scope_name):
        # Both networks are constructed identically; only the scope name differs.
        return ModelCatalog.get_model_v2(
            policy.observation_space,
            policy.action_space,
            num_outputs,
            model_config,
            name=scope_name,
            framework=policy.framework,
        )

    # Main network first, then the target, matching the original build order.
    policy.model = _build(POLICY_SCOPE)
    policy.model_variables = policy.model.variables()

    policy.target_model = _build(TARGET_POLICY_SCOPE)
    policy.target_model_variables = policy.target_model.variables()

    return policy.model