mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from ray.rllib.models.catalog import ModelCatalog
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
|
|
|
|
# Name given to the main policy model built by `make_appo_models`.
POLICY_SCOPE = "func"
# Name given to the target model built by `make_appo_models`.
TARGET_POLICY_SCOPE = "target_func"
|
|
|
|
|
|
def make_appo_models(policy) -> ModelV2:
    """Create the main model and the target model for an APPO policy.

    Both models are built through `ModelCatalog.get_model_v2` from the same
    spaces, model config and framework; they differ only in the name they
    are given (`POLICY_SCOPE` vs. `TARGET_POLICY_SCOPE`).

    Side effects on `policy`:
        * `policy.model` / `policy.model_variables` are set to the main
          model and the result of its `variables()` call.
        * `policy.target_model` / `policy.target_model_variables` are set
          to the target model and the result of its `variables()` call.

    Args:
        policy: The policy to build models for. Must expose
            `observation_space`, `action_space`, `framework` and a
            `config` mapping containing a "model" entry.

    Returns:
        ModelV2: The main model (`policy.model`). The target model is not
            returned; it is only reachable via `policy.target_model`.
    """
    # The second element of the action-dist lookup is the number of model
    # outputs required for this action space; the first one is not needed.
    _dist_cls, num_outputs = ModelCatalog.get_action_dist(
        policy.action_space, policy.config["model"]
    )

    def _build(scope_name):
        # Shared construction call; only the model name varies.
        return ModelCatalog.get_model_v2(
            policy.observation_space,
            policy.action_space,
            num_outputs,
            policy.config["model"],
            name=scope_name,
            framework=policy.framework,
        )

    # Main model first, then its variables.
    policy.model = _build(POLICY_SCOPE)
    policy.model_variables = policy.model.variables()

    # Target model second, then its variables.
    policy.target_model = _build(TARGET_POLICY_SCOPE)
    policy.target_model_variables = policy.target_model.variables()

    # Only the main model is handed back to the caller.
    return policy.model
|