"""Policy Gradient (PG) trainer, defined via the generic trainer template."""

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # No remote workers by default.
    "num_workers": 0,
    # Learning rate.
    "lr": 0.0004,
})
# __sphinx_doc_end__
# yapf: enable


def get_policy_class(config):
    """Return the torch policy when the torch framework is configured, else the TF policy."""
    if config["framework"] == "torch":
        from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
        return PGTorchPolicy
    else:
        return PGTFPolicy


# Build the PG Trainer class from the default config, the default (TF) policy,
# and the framework-dependent policy selector above.
PGTrainer = build_trainer(
    name="PG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class)
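

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of how PGTrainer might be driven once built. The
# environment name "CartPole-v0" and the number of training iterations are
# assumptions made only for this example.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = PGTrainer(env="CartPole-v0", config={"framework": "tf"})
    for _ in range(3):
        result = trainer.train()  # run one training iteration, returns metrics
        print("episode_reward_mean:", result["episode_reward_mean"])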