ray/rllib/agents/pg/pg.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

34 lines
1.2 KiB
Python

from typing import Type
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
class PGTrainer(Trainer):
    """Trainer for the Policy Gradient (PG) algorithm.

    This is the distributed Trainer class for policy gradients. The policy
    losses themselves are defined in `pg_[tf|torch]_policy.py` (one module
    each for TensorFlow and PyTorch).

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    The class customizes only two selectors, `get_default_config` and
    `get_default_policy_class`; the `execution_plan()` is the default one
    inherited from `Trainer`.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the imported PG `DEFAULT_CONFIG` object (not a copy)."""
        return DEFAULT_CONFIG

    @override(Trainer)
    def get_default_policy_class(self, config) -> Type[Policy]:
        """Select the policy class matching the configured framework.

        Args:
            config: Trainer config; only its "framework" entry is read,
                via `.get`, so a missing key is tolerated.

        Returns:
            `PGTorchPolicy` when the "framework" entry equals "torch";
            `PGTFPolicy` for any other value, including a missing key.
        """
        wants_torch = config.get("framework") == "torch"
        if wants_torch:
            return PGTorchPolicy
        # Every non-"torch" setting falls back to the TensorFlow policy.
        return PGTFPolicy