# Mirror of https://github.com/vale981/ray, synced 2025-03-07 02:51:39 -05:00.
"""
|
|
Policy Gradient (PG)
|
|
====================
|
|
|
|
This file defines the distributed Trainer class for policy gradients.
|
|
See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
|
|
|
|
Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg
|
|
"""
|
|
|
|
import logging
from typing import Optional, Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)
|
|
|
|
# yapf: disable
|
|
# __sphinx_doc_begin__
|
|
|
|
# Adds the following updates to the (base) `Trainer` config in
|
|
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
|
|
DEFAULT_CONFIG = with_common_config({
|
|
# No remote workers by default.
|
|
"num_workers": 0,
|
|
# Learning rate.
|
|
"lr": 0.0004,
|
|
})
|
|
|
|
# __sphinx_doc_end__
|
|
# yapf: enable
|
|
|
|
|
|
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
|
|
"""Policy class picker function. Class is chosen based on DL-framework.
|
|
|
|
Args:
|
|
config (TrainerConfigDict): The trainer's configuration dict.
|
|
|
|
Returns:
|
|
Optional[Type[Policy]]: The Policy class to use with PGTrainer.
|
|
If None, use `default_policy` provided in build_trainer().
|
|
"""
|
|
if config["framework"] == "torch":
|
|
return PGTorchPolicy
|
|
|
|
|
|
# Build a child class of `Trainer`, which uses the framework specific Policy
|
|
# determined in `get_policy_class()` above.
|
|
PGTrainer = build_trainer(
|
|
name="PG",
|
|
default_config=DEFAULT_CONFIG,
|
|
default_policy=PGTFPolicy,
|
|
get_policy_class=get_policy_class,
|
|
)
|