ray/rllib/algorithms/pg/pg.py

from typing import Type

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import AlgorithmConfigDict


class PGConfig(AlgorithmConfig):
    """Defines a configuration class from which a PG Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.pg import PGConfig
        >>> config = PGConfig().training(lr=0.01).resources(num_gpus=1)
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")  # doctest: +SKIP
        >>> trainer.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.pg import PGConfig
        >>> from ray import tune
        >>> config = PGConfig()
        >>> # Print out some default values.
        >>> print(config.lr)  # doctest: +SKIP
        0.0004
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(  # doctest: +SKIP
        ...     "PG",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """
    def __init__(self):
        """Initializes a PGConfig instance."""
        super().__init__(algo_class=PG)

        # fmt: off
        # __sphinx_doc_begin__
        # Override some of AlgorithmConfig's default values with
        # PG-specific values.
        self.num_workers = 0
        self.lr = 0.0004
        self._disable_preprocessor_api = True
        # __sphinx_doc_end__
        # fmt: on
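
# Usage sketch for the defaults above (illustration only; not part of the
# library logic). PGConfig changes just three AlgorithmConfig defaults --
# `num_workers`, `lr`, and `_disable_preprocessor_api` -- so the generic
# fluent setters from the class docstring still apply unchanged. A minimal
# sketch, assuming a working `ray[rllib]` install:
#
#     config = PGConfig().training(lr=0.001, gamma=0.95)
#     assert config.to_dict()["lr"] == 0.001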


class PG(Algorithm):
    """Policy Gradient (PG) Trainer.

    Defines the distributed Trainer class for policy gradients.
    See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
    TensorFlow and PyTorch.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    Only overrides the default config- and policy selectors
    (`get_default_config` and `get_default_policy_class`). Utilizes
    the default `execution_plan()` of `Trainer`.
    """
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return PGConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(self, config) -> Type[Policy]:
        # Pick the policy class matching the configured DL framework.
        if config["framework"] == "torch":
            from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy

            return PGTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy

            return PGTF1Policy
        else:
            from ray.rllib.algorithms.pg.pg_tf_policy import PGTF2Policy

            return PGTF2Policy
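
# Framework-selection sketch (illustration only): the branching above maps
# the "framework" config key to a policy class -- "torch" -> PGTorchPolicy,
# "tf" -> PGTF1Policy (TF static graph), anything else (e.g. "tf2") ->
# PGTF2Policy (TF eager). A minimal sketch, assuming a working RLlib install:
#
#     config = PGConfig().framework("torch")
#     algo = config.build(env="CartPole-v1")
#     print(type(algo.get_policy()))  # expected: PGTorchPolicy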


# Deprecated: Use ray.rllib.algorithms.pg.PGConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(PGConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.pg.default_config::DEFAULT_CONFIG",
        new="ray.rllib.algorithms.pg.pg::PGConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
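

if __name__ == "__main__":
    # A minimal smoke-test sketch (not part of the module's public API):
    # build a PG Trainer on CartPole and run one training iteration.
    # Assumes `ray[rllib]`, PyTorch, and a Gym CartPole environment are
    # installed; the "CartPole-v1" env id is taken from the docstring
    # examples above.
    config = PGConfig().framework("torch").training(lr=0.0004)
    trainer = config.build(env="CartPole-v1")
    result = trainer.train()
    print("episode_reward_mean:", result.get("episode_reward_mean"))
    trainer.stop()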