# File: ray/rllib/agents/pg/pg.py (35 lines, 1.2 KiB, Python)

from typing import Type
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
class PGTrainer(Trainer):
    """Policy Gradient (PG) Trainer.

    Defines the distributed Trainer class for policy gradients.
    See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
    TensorFlow and PyTorch.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    Only overrides the default config- and policy selectors
    (`get_default_policy_class` and `get_default_config`). Utilizes
    the default `execution_plan()` of `Trainer`.
    """

    @override(Trainer)
    def get_default_policy_class(
        self, config: TrainerConfigDict
    ) -> Type[Policy]:
        """Return the policy class matching the configured framework.

        Args:
            config: The Trainer's config dict. Only its "framework" entry
                is read.

        Returns:
            `PGTorchPolicy` if `config["framework"] == "torch"`; otherwise
            `PGTFPolicy`. A config without a "framework" key also yields
            `PGTFPolicy`, because `config.get` returns None in that case.
        """
        if config.get("framework") == "torch":
            return PGTorchPolicy
        return PGTFPolicy

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the PG default config.

        Returns:
            The module-level `DEFAULT_CONFIG` object itself (not a copy), so
            callers should avoid mutating it in place.
        """
        return DEFAULT_CONFIG