ray/rllib/agents/a3c/a2c.py

from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
    validate_config, get_policy_class
from ray.rllib.optimizers import SyncSamplesOptimizer, MicrobatchOptimizer
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.utils import merge_dicts

A2C_DEFAULT_CONFIG = merge_dicts(
    A3C_CONFIG,
    {
        "sample_batch_size": 20,
        "min_iter_time_s": 10,
        "sample_async": False,

        # A2C supports microbatching, in which we accumulate gradients over
        # batches of this size until the train batch size is reached. This
        # allows training with batch sizes much larger than can fit in GPU
        # memory. To enable, set this to a value less than the train batch
        # size.
        "microbatch_size": None,
    },
)
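
# Illustration (added note; the values are assumptions, not RLlib defaults):
# a user config enabling microbatching might look like
#
#     microbatch_config = merge_dicts(A2C_DEFAULT_CONFIG, {
#         "train_batch_size": 1000,
#         "microbatch_size": 100,
#     })
#
# With these numbers, the MicrobatchOptimizer chosen below would accumulate
# 1000 / 100 = 10 microbatch gradients before applying a single update, so
# only one microbatch of samples has to fit in GPU memory at a time.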


def choose_policy_optimizer(workers, config):
    # Use the gradient-accumulating MicrobatchOptimizer only when a
    # microbatch size is configured; otherwise fall back to the standard
    # synchronous sample optimizer.
    if config["microbatch_size"]:
        return MicrobatchOptimizer(
            workers,
            train_batch_size=config["train_batch_size"],
            microbatch_size=config["microbatch_size"])
    else:
        return SyncSamplesOptimizer(
            workers, train_batch_size=config["train_batch_size"])


A2CTrainer = build_trainer(
    name="A2C",
    default_config=A2C_DEFAULT_CONFIG,
    default_policy=A3CTFPolicy,
    get_policy_class=get_policy_class,
    make_policy_optimizer=choose_policy_optimizer,
    validate_config=validate_config)
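

# Usage sketch (an added example for illustration, not part of the original
# module): run the A2C trainer directly, e.g. on CartPole. The environment
# name and config values below are example choices.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = A2CTrainer(env="CartPole-v0", config={"num_workers": 2})
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])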