from typing import Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict


class PGTrainer(Trainer):
    """Policy Gradient (PG) Trainer.

    Defines the distributed Trainer class for policy gradients.
    See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
    TensorFlow and PyTorch.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    Only overrides the default config- and policy selectors
    (`get_default_policy_class` and `get_default_config`). Utilizes
    the default `execution_plan()` of `Trainer`.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def get_default_policy_class(self, config) -> Type[Policy]:
        return PGTorchPolicy if config.get("framework") == "torch" else PGTFPolicy
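

# A minimal usage sketch, not part of the original module: it assumes `ray` is
# installed and that the Gym environment id "CartPole-v0" is registered. The
# config keys and values below ("framework", "num_workers") are illustrative
# choices, not library defaults.
if __name__ == "__main__":
    import ray

    ray.init()
    # "torch" selects PGTorchPolicy via get_default_policy_class(); any other
    # framework setting falls back to PGTFPolicy.
    trainer = PGTrainer(env="CartPole-v0",
                        config={"framework": "torch", "num_workers": 1})
    for i in range(3):
        # One training iteration; returns a metrics dict.
        result = trainer.train()
        print(i, result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()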