from typing import Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict


class PGTrainer(Trainer):
    """Policy Gradient (PG) Trainer.

    Defines the distributed Trainer class for policy gradients.
    See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
    TensorFlow and PyTorch.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    Only overrides the default config- and policy selectors
    (`get_default_policy_class` and `get_default_config`). Utilizes
    the default `execution_plan()` of `Trainer`.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        # Pick the torch policy when the config requests the torch framework;
        # otherwise fall back to the TF policy.
        return PGTorchPolicy if config.get("framework") == "torch" \
            else PGTFPolicy
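

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how the config and
# policy selectors above are exercised when a PGTrainer is built and trained.
# Assumes `ray` is installed, a Gym environment such as "CartPole-v0" is
# available, and that the keys overridden below ("framework", "num_workers")
# exist in DEFAULT_CONFIG / the common Trainer config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import ray

    ray.init()
    # `get_default_config()` supplies DEFAULT_CONFIG; the dict below only
    # overrides a few keys. Setting "framework" to "torch" makes
    # `get_default_policy_class()` return PGTorchPolicy.
    trainer = PGTrainer(
        env="CartPole-v0",
        config={"framework": "torch", "num_workers": 0},
    )
    result = trainer.train()
    print("episode_reward_mean:", result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()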