from typing import Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict


class PGTrainer(Trainer):
    """Trainer for the Policy Gradient (PG) algorithm.

    This is the distributed Trainer class for policy gradients. The policy
    losses themselves live in `pg_[tf|torch]_policy.py` (one file per
    framework: TensorFlow and PyTorch).

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    The class overrides only the two selector hooks below,
    `get_default_policy_class` and `get_default_config`, and relies on the
    default `execution_plan()` of `Trainer` for everything else.
    """

    @override(Trainer)
    def get_default_policy_class(self, config) -> Type[Policy]:
        """Pick the policy class that matches the configured framework.

        Args:
            config: Mapping-like trainer config; only its "framework" entry
                is read (via `.get`, so the key may be absent).

        Returns:
            `PGTorchPolicy` when `config["framework"]` equals "torch";
            `PGTFPolicy` for every other value, including a missing key.
        """
        if config.get("framework") == "torch":
            return PGTorchPolicy
        # Anything that is not exactly "torch" falls back to the TF policy.
        return PGTFPolicy

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the PG default config.

        Returns:
            The `DEFAULT_CONFIG` object imported from
            `ray.rllib.agents.pg.default_config`, handed back as-is (this
            method makes no copy).
        """
        return DEFAULT_CONFIG