ray/rllib/agents/pg/pg.py

"""
Policy Gradient (PG)
====================
This file defines the distributed Trainer class for policy gradients.
See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation: https://docs.ray.io/en/latest/rllib-algorithms.html#pg
"""

from typing import Optional, Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict

# yapf: disable
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
    # No remote workers by default.
    "num_workers": 0,
    # Learning rate.
    "lr": 0.0004,
})
# __sphinx_doc_end__
# yapf: enable
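
# A minimal sketch of overriding these defaults (illustrative values only,
# not defaults shipped with RLlib; any key from the base `Trainer`'s
# COMMON_CONFIG can also be set here):
#
#   config = DEFAULT_CONFIG.copy()
#   config["num_workers"] = 2      # parallel rollout workers
#   config["lr"] = 0.001           # larger learning rate
#   config["framework"] = "torch"  # selects PGTorchPolicy (see below)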


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Policy class picker function. Class is chosen based on DL-framework.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with PGTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":
        return PGTorchPolicy


# Build a child class of `Trainer`, which uses the framework-specific Policy
# determined in `get_policy_class()` above.
PGTrainer = build_trainer(
    name="PG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class,
)
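

# A minimal usage sketch (assumptions: Ray is installed and the Gym env
# "CartPole-v0" is available; neither is guaranteed by this file itself).
# Running this module directly trains PG for a few iterations.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = PGTrainer(env="CartPole-v0", config={"framework": "tf"})
    for _ in range(3):
        # Each call runs one training iteration and returns a result dict.
        print(trainer.train())
    trainer.stop()
    ray.shutdown()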