mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00

* Remove all __future__ imports from RLlib.
* Remove (object) again from tf_run_builder.py::TFRunBuilder.
* Fix two LINT warnings.
* Fix broken appo_policy import (must be appo_tf_policy).
* Remove __future__ imports from all other ray files (not just RLlib).
* Also remove __future__ import blocks that contain `unicode_literals`. Revert appo_tf_policy.py to appo_policy.py (that change belongs to another PR).
* Add two empty lines before the Schedule class.
* Put the __future__ imports back into determine_tests_to_run.py; otherwise it fails with a py2 print-related error.
41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
|
|
validate_config, get_policy_class
|
|
from ray.rllib.optimizers import SyncSamplesOptimizer, MicrobatchOptimizer
|
|
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
|
from ray.rllib.agents.trainer_template import build_trainer
|
|
from ray.rllib.utils import merge_dicts
|
|
|
|
# Default A2C configuration: start from A3C_CONFIG and merge in the
# A2C-specific settings below via merge_dicts.
A2C_DEFAULT_CONFIG = merge_dicts(
    A3C_CONFIG,
    dict(
        sample_batch_size=20,
        min_iter_time_s=10,
        sample_async=False,
        # A2C supports microbatching, in which we accumulate gradients over
        # batches of this size until the train batch size is reached. This
        # allows training with batch sizes much larger than can fit in GPU
        # memory. To enable, set this to a value less than the train batch
        # size.
        microbatch_size=None,
    ),
)
|
|
|
|
|
|
def choose_policy_optimizer(workers, config):
    """Build the policy optimizer used by the A2C trainer.

    Args:
        workers: Passed unchanged as the first argument of the chosen
            optimizer.
        config: Trainer config dict; its "microbatch_size" and
            "train_batch_size" entries are read.

    Returns:
        A ``MicrobatchOptimizer`` when ``config["microbatch_size"]`` is
        truthy, otherwise a ``SyncSamplesOptimizer``.
    """
    microbatch_size = config["microbatch_size"]
    train_batch_size = config["train_batch_size"]

    # Microbatching disabled (e.g. None): fall back to plain synchronous
    # sampling over the full train batch.
    if not microbatch_size:
        return SyncSamplesOptimizer(
            workers, train_batch_size=train_batch_size)

    return MicrobatchOptimizer(
        workers,
        train_batch_size=train_batch_size,
        microbatch_size=microbatch_size)
|
|
|
|
|
|
# Assemble the A2C trainer class from the generic trainer template. It reuses
# the A3C TF policy together with A3C's policy-class selector and config
# validator, and picks its optimizer via choose_policy_optimizer.
A2CTrainer = build_trainer(
    name="A2C",
    default_config=A2C_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=A3CTFPolicy,
    get_policy_class=get_policy_class,
    make_policy_optimizer=choose_policy_optimizer,
)
|