ray/rllib/agents/pg/pg.py

from typing import Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict


class PGTrainer(Trainer):
"""Policy Gradient (PG) Trainer.
Defines the distributed Trainer class for policy gradients.
See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
TensorFlow and PyTorch.
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#pg
Only overrides the default config- and policy selectors
(`get_default_policy` and `get_default_config`). Utilizes
the default `execution_plan()` of `Trainer`.
"""

    @override(Trainer)
    def get_default_policy_class(
            self, config: TrainerConfigDict) -> Type[Policy]:
        # Pick the torch or tf policy class based on the `framework` setting.
        return PGTorchPolicy if config.get("framework") == "torch" \
            else PGTFPolicy

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        # Return the PG-specific default config defined in `default_config.py`.
        return DEFAULT_CONFIG
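

# A minimal usage sketch (assumes `ray[rllib]` is installed and the Gym
# environment "CartPole-v0" is available): running this module directly
# builds a PGTrainer with the torch framework and prints the mean episode
# reward over a few training iterations.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = PGTrainer(
        env="CartPole-v0",
        config={"framework": "torch", "num_workers": 1})
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])
    ray.shutdown()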