2020-02-19 21:18:45 +01:00
|
|
|
import logging
|
|
|
|
|
2019-05-20 16:46:05 -07:00
|
|
|
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
2019-06-03 06:49:24 +08:00
|
|
|
from ray.rllib.agents.trainer import with_common_config
|
|
|
|
from ray.rllib.agents.trainer_template import build_trainer
|
2018-06-27 02:30:15 -07:00
|
|
|
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
2020-03-02 15:16:37 -08:00
|
|
|
from ray.rllib.utils.experimental_dsl import (AsyncGradients, ApplyGradients,
|
|
|
|
StandardMetricsReporting)
|
2017-06-29 08:49:56 -07:00
|
|
|
|
2020-02-19 21:18:45 +01:00
|
|
|
# Module-level logger; validate_config uses it to warn when it overrides a
# user-supplied setting.
logger = logging.getLogger(__name__)
|
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
# yapf: disable
|
2018-10-16 15:55:11 -07:00
|
|
|
# __sphinx_doc_begin__
|
2018-07-01 00:05:08 -07:00
|
|
|
DEFAULT_CONFIG = with_common_config({
    # Whether to use a critic (value function) as a baseline. If False, no
    # value baseline is used. A critic is required for GAE ("use_gae").
    "use_critic": True,
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # Size of each rollout fragment (sample batch) collected by a worker.
    "rollout_fragment_length": 10,
    # GAE(lambda) parameter.
    "lambda": 1.0,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Learning rate schedule
    "lr_schedule": None,
    # Value function loss coefficient
    "vf_loss_coeff": 0.5,
    # Entropy coefficient (must be >= 0; checked in validate_config).
    "entropy_coeff": 0.01,
    # Min time per iteration, in seconds.
    "min_iter_time_s": 5,
    # Workers sample async. Note that this increases the effective
    # rollout_fragment_length by up to 5x due to async buffering of batches.
    # Forced to False by validate_config when "use_pytorch" is set.
    "sample_async": True,
    # Use the execution plan API instead of policy optimizers.
    "use_exec_api": True,
})
|
2018-10-16 15:55:11 -07:00
|
|
|
# __sphinx_doc_end__
|
2018-10-21 23:43:57 -07:00
|
|
|
# yapf: enable
|
2018-10-16 15:55:11 -07:00
|
|
|
|
2017-06-29 08:49:56 -07:00
|
|
|
|
2019-06-03 06:49:24 +08:00
|
|
|
def get_policy_class(config):
    """Pick the A3C policy class matching the configured framework.

    Args:
        config (dict): Trainer config; only ``config["use_pytorch"]`` is read.

    Returns:
        A3CTFPolicy when "use_pytorch" is falsy, otherwise A3CTorchPolicy.
    """
    if not config["use_pytorch"]:
        return A3CTFPolicy
    # Imported lazily so the torch policy module is only loaded when the
    # torch framework is actually requested.
    from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
    return A3CTorchPolicy
|
2018-07-01 00:05:08 -07:00
|
|
|
|
2017-10-10 12:49:42 -07:00
|
|
|
|
2019-06-03 06:49:24 +08:00
|
|
|
def validate_config(config):
    """Validate an A3C trainer config, adjusting it in place where needed.

    Args:
        config (dict): The trainer config. Reads "entropy_coeff",
            "sample_async" and "use_pytorch"; may set "sample_async" to False.

    Raises:
        DeprecationWarning: If ``config["entropy_coeff"]`` is negative.
    """
    if config["entropy_coeff"] < 0:
        # NOTE(review): ValueError would be the conventional exception type
        # here; DeprecationWarning is kept so existing callers that catch it
        # continue to work.
        raise DeprecationWarning("entropy_coeff must be >= 0")
    if config["sample_async"] and config["use_pytorch"]:
        # Async sampling relies on multithreading, which can crash with
        # pytorch, so it is switched off instead of rejecting the config.
        config["sample_async"] = False
        logger.warning(
            "The sample_async option is not supported with use_pytorch: "
            "Multithreading can lead to crashes if used with pytorch.")
|
2018-07-01 00:05:08 -07:00
|
|
|
|
2019-03-17 18:07:37 -07:00
|
|
|
|
2019-06-03 06:49:24 +08:00
|
|
|
def make_async_optimizer(workers, config):
    """Create the asynchronous-gradients policy optimizer for A3C.

    Args:
        workers: The trainer's worker set, forwarded unchanged.
        config (dict): Trainer config; ``config["optimizer"]`` is expanded as
            keyword arguments to AsyncGradientsOptimizer.

    Returns:
        An AsyncGradientsOptimizer instance.
    """
    optimizer_kwargs = config["optimizer"]
    return AsyncGradientsOptimizer(workers, **optimizer_kwargs)
|
2018-08-20 15:28:03 -07:00
|
|
|
|
2018-12-08 16:28:58 -08:00
|
|
|
|
2020-03-12 00:54:08 -07:00
|
|
|
# Experimental distributed execution impl; used when "use_exec_api" is True
# (the default in DEFAULT_CONFIG).
def execution_plan(workers, config):
    """Build the A3C training dataflow.

    Gradients are computed remotely on the rollout workers and applied as
    they arrive; the resulting iterator is wrapped with standard metrics
    reporting.

    Args:
        workers: The trainer's worker set.
        config (dict): Trainer config, forwarded to StandardMetricsReporting.

    Returns:
        The iterator returned by StandardMetricsReporting.
    """
    # Policy gradients are computed on the remote rollout workers.
    gradient_stream = AsyncGradients(workers)
    # update_all=False: only the worker that sent a gradient receives the
    # freshly updated weights.
    apply_op = ApplyGradients(workers, update_all=False)
    train_op = gradient_stream.for_each(apply_op)
    return StandardMetricsReporting(train_op, workers, config)
|
|
|
|
|
|
|
|
|
2019-06-03 06:49:24 +08:00
|
|
|
# Assemble the A3C trainer from the shared trainer template. The policy class
# is chosen per config by get_policy_class. Training runs through
# execution_plan when "use_exec_api" is set, otherwise through the
# async-gradients policy optimizer.
A3CTrainer = build_trainer(
    name="A3C",
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=A3CTFPolicy,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan,
    make_policy_optimizer=make_async_optimizer,
)
|