# Mirror of https://github.com/vale981/ray (synced 2025-03-08 19:41:38 -05:00).
# 35 lines, 892 B, Python.
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from ray.rllib.agents.trainer import with_common_config
|
|
from ray.rllib.agents.trainer_template import build_trainer
|
|
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
|
|
|
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Number of remote workers; 0 means none are used by default.
    "num_workers": 0,
    # Learning rate.
    "lr": 0.0004,
    # Select the PyTorch policy instead of the default TF one.
    "use_pytorch": False,
})
# __sphinx_doc_end__
# yapf: enable
|
|
|
|
|
|
def get_policy_class(config):
    """Select the policy class matching the configured backend.

    Args:
        config (dict): Trainer config; its "use_pytorch" entry is read.

    Returns:
        PGTorchPolicy when config["use_pytorch"] is truthy, PGTFPolicy
        otherwise.
    """
    if not config["use_pytorch"]:
        return PGTFPolicy

    # The torch policy module is imported only when this branch is taken.
    from ray.rllib.agents.pg.torch_pg_policy import PGTorchPolicy
    return PGTorchPolicy
|
|
|
|
|
|
# Trainer named "PG", assembled by build_trainer from the default config,
# with PGTFPolicy as the default policy and get_policy_class passed through
# under the keyword of the same name.
PGTrainer = build_trainer(
    name="PG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class,
)
|