ray/rllib/agents/pg/pg.py

"""
Policy Gradient (PG)
====================
This file defines the distributed Trainer class for policy gradients.
See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg
"""
import logging
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
})
# __sphinx_doc_end__
# yapf: enable
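
# Usage sketch (an illustrative addition, not part of the original file):
# any of the keys above, plus the base `COMMON_CONFIG` keys, can be
# overridden by passing a partial config dict when constructing the
# trainer, e.g.:
#
#     trainer = PGTrainer(env="CartPole-v0",
#                         config={"lr": 0.001, "num_workers": 2})
#
# (`CartPole-v0` and the override values are examples only.)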
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Policy class picker function. Class is chosen based on DL-framework.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with PGTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":
        return PGTorchPolicy
    # Returning None makes `build_trainer()` fall back to its
    # `default_policy` (the TF policy below).
    return None
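
# Illustrative note (not from the original file): setting
# `"framework": "torch"` in the config makes `get_policy_class()` return
# `PGTorchPolicy`; any other framework falls through to None, so the TF
# `default_policy` passed to `build_trainer()` below is used instead, e.g.:
#
#     torch_trainer = PGTrainer(env="CartPole-v0",
#                               config={"framework": "torch"})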
# Build a child class of `Trainer`, which uses the framework-specific Policy
# determined in `get_policy_class()` above.
PGTrainer = build_trainer(
    name="PG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class,
)
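
# A minimal, runnable training sketch (an addition, not part of the original
# module). It assumes `ray[rllib]` and `gym` are installed and uses the
# standard Gym env "CartPole-v0"; the config values and iteration count are
# arbitrary illustrations.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = PGTrainer(env="CartPole-v0", config={"num_workers": 1})
    for i in range(3):
        result = trainer.train()
        # `train()` returns a result dict with standard RLlib metrics.
        print(f"iter={i} episode_reward_mean={result['episode_reward_mean']}")
    ray.shutdown()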