from typing import Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict


class PGConfig(TrainerConfig):
    """Defines a configuration class from which a PG Trainer can be built.

    Example:
        >>> from ray.rllib.algorithms.pg import PGConfig
        >>> config = PGConfig().training(lr=0.01).resources(num_gpus=1)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.algorithms.pg import PGConfig
        >>> from ray import tune
        >>> config = PGConfig()
        >>> # Print out some default values.
        >>> print(config.lr)  # doctest: +SKIP
        0.0004
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "PG",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self):
        """Initializes a PGConfig instance."""
        super().__init__(trainer_class=PG)

        # fmt: off
        # __sphinx_doc_begin__
        # Override some of TrainerConfig's default values with PG-specific values.
        self.num_workers = 0
        self.lr = 0.0004
        self._disable_preprocessor_api = True
        # __sphinx_doc_end__
        # fmt: on


class PG(Trainer):
    """Policy Gradient (PG) Trainer.

    Defines the distributed Trainer class for policy gradients.
    See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
    TensorFlow and PyTorch.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#pg

    Only overrides the default config- and policy selectors
    (`get_default_config` and `get_default_policy_class`). Utilizes
    the default `execution_plan()` of `Trainer`.
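
    Example (a minimal usage sketch, assuming a local Ray installation and the
    Gym `CartPole-v1` environment; mirrors the `PGConfig` examples above):
        >>> from ray.rllib.algorithms.pg import PGConfig  # doctest: +SKIP
        >>> trainer = PGConfig().build(env="CartPole-v1")  # doctest: +SKIP
        >>> print(trainer.train())  # doctest: +SKIP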
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return PGConfig().to_dict()

    @override(Trainer)
    def get_default_policy_class(self, config) -> Type[Policy]:
        # Select the PG policy implementation matching the configured framework.
        if config["framework"] == "torch":
            from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy

            return PGTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.pg.pg_tf_policy import PGStaticGraphTFPolicy

            return PGStaticGraphTFPolicy
        else:
            # Eager-mode TF frameworks ("tf2"/"tfe").
            from ray.rllib.algorithms.pg.pg_tf_policy import PGEagerTFPolicy

            return PGEagerTFPolicy


# Deprecated: Use ray.rllib.algorithms.pg.PGConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(PGConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.pg.default_config::DEFAULT_CONFIG",
        new="ray.rllib.algorithms.pg.pg::PGConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
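

if __name__ == "__main__":
    # A minimal smoke-test sketch, mirroring the PGConfig docstring examples
    # above: build a PG Trainer from a PGConfig and run one training iteration.
    # Assumes a local Ray installation and the Gym "CartPole-v1" environment.
    config = PGConfig().training(lr=0.0004).resources(num_gpus=0)
    trainer = config.build(env="CartPole-v1")
    print(trainer.train())
    trainer.stop()

    # Old-style dict access still works but logs a deprecation warning.
    print(DEFAULT_CONFIG["lr"])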