2019-05-18 00:23:11 -07:00
|
|
|
from ray.rllib.agents.trainer import with_common_config
|
|
|
|
from ray.rllib.agents.trainer_template import build_trainer
|
2020-01-02 19:08:03 -05:00
|
|
|
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
|
2018-07-01 00:05:08 -07:00
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
# yapf: disable
|
2018-10-16 15:55:11 -07:00
|
|
|
# __sphinx_doc_begin__
|
2018-07-01 00:05:08 -07:00
|
|
|
DEFAULT_CONFIG = with_common_config({
    # Number of remote workers; 0 means none are used by default.
    "num_workers": 0,

    # Learning rate.
    "lr": 0.0004,
})
|
2018-10-16 15:55:11 -07:00
|
|
|
# __sphinx_doc_end__
|
2018-10-21 23:43:57 -07:00
|
|
|
# yapf: enable
|
2018-10-16 15:55:11 -07:00
|
|
|
|
2018-07-01 00:05:08 -07:00
|
|
|
|
2019-05-18 00:23:11 -07:00
|
|
|
def get_policy_class(config):
    """Return the policy class matching ``config["framework"]``.

    Args:
        config: Mapping that must contain a ``"framework"`` key.

    Returns:
        ``PGTorchPolicy`` when ``config["framework"] == "torch"``;
        ``PGTFPolicy`` for any other value.

    Raises:
        KeyError: If ``config`` has no ``"framework"`` entry.
    """
    # Anything other than "torch" falls back to the TF policy.
    if config["framework"] != "torch":
        return PGTFPolicy

    # The torch policy module is imported only on this path
    # (presumably so torch is not required otherwise -- confirm).
    from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
    return PGTorchPolicy
|
2018-07-01 00:05:08 -07:00
|
|
|
|
|
|
|
|
2019-05-18 00:23:11 -07:00
|
|
|
# Trainer for "PG", assembled via build_trainer from the trainer template.
# PGTFPolicy is passed as the default policy, and get_policy_class is
# passed alongside it; that function returns the torch policy when the
# config's "framework" is "torch".
PGTrainer = build_trainer(
    name="PG",
    default_config=DEFAULT_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class,
)
|