ray/rllib/agents/impala/impala.py

from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.utils.annotations import override
from ray.tune.trainable import Trainable
from ray.tune.resources import Resources

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # V-trace params (see vtrace.py).
    "vtrace": True,
    "vtrace_clip_rho_threshold": 1.0,
    "vtrace_clip_pg_rho_threshold": 1.0,
    # System params.
    #
    # == Overview of data flow in IMPALA ==
    # 1. Policy evaluation in parallel across `num_workers` actors produces
    #    batches of size `rollout_fragment_length * num_envs_per_worker`.
    # 2. If enabled, the replay buffer stores and produces batches of size
    #    `rollout_fragment_length * num_envs_per_worker`.
    # 3. If enabled, the minibatch ring buffer stores and replays batches of
    #    size `train_batch_size` up to `num_sgd_iter` times per batch.
    # 4. The learner thread executes data parallel SGD across `num_gpus` GPUs
    #    on batches of size `train_batch_size`.
    # (A sketch below this config illustrates these sizes for the defaults.)
    #
    "rollout_fragment_length": 50,
    "train_batch_size": 500,
    "min_iter_time_s": 10,
    "num_workers": 2,
    # number of GPUs the learner should use.
    "num_gpus": 1,
    # set >1 to load data into GPUs in parallel. Increases GPU memory usage
    # proportionally with the number of buffers.
    "num_data_loader_buffers": 1,
    # how many train batches should be retained for minibatching. This conf
    # only has an effect if `num_sgd_iter > 1`.
    "minibatch_buffer_size": 1,
    # number of passes to make over each train batch
    "num_sgd_iter": 1,
    # set >0 to enable experience replay. Saved samples will be replayed with
    # a p:1 proportion to new data samples.
    "replay_proportion": 0.0,
    # number of sample batches to store for replay. The number of transitions
    # saved total will be (replay_buffer_num_slots * rollout_fragment_length).
    "replay_buffer_num_slots": 0,
    # max queue size for train batches feeding into the learner
    "learner_queue_size": 16,
    # wait for train batches to be available in minibatch buffer queue
    # this many seconds. This may need to be increased e.g. when training
    # with a slow environment
    "learner_queue_timeout": 300,
    # level of queuing for sampling.
    "max_sample_requests_in_flight_per_worker": 2,
    # max number of workers to broadcast one set of weights to
    "broadcast_interval": 1,
    # use intermediate actors for multi-level aggregation. This can make sense
    # if ingesting >2GB/s of samples, or if the data requires decompression.
    "num_aggregation_workers": 0,

    # Learning params.
    "grad_clip": 40.0,
    # either "adam" or "rmsprop"
    "opt_type": "adam",
    "lr": 0.0005,
    "lr_schedule": None,
    # rmsprop-specific settings (used only when `opt_type` is "rmsprop")
    "decay": 0.99,
    "momentum": 0.0,
    "epsilon": 0.1,
    # balancing the three losses
    "vf_loss_coeff": 0.5,
    "entropy_coeff": 0.01,
    "entropy_coeff_schedule": None,

    # use fake (infinite speed) sampler for testing
    "_fake_sampler": False,
})
# __sphinx_doc_end__
# yapf: enable
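

# Illustrative sketch (not used by the trainer itself): how the batch sizes
# from the data-flow overview in DEFAULT_CONFIG relate for the defaults above.
# `num_envs_per_worker` comes from the common config and defaults to 1.
def _example_batch_sizes(config=DEFAULT_CONFIG):
    # Each of the `num_workers` rollout actors returns batches of this many
    # timesteps (50 * 1 = 50 with the defaults).
    rollout_batch = (config["rollout_fragment_length"] *
                     config.get("num_envs_per_worker", 1))
    # The learner thread trains on `train_batch_size` timesteps at a time,
    # i.e. roughly this many rollout batches per SGD step (500 // 50 = 10).
    rollouts_per_train_batch = config["train_batch_size"] // rollout_batch
    # If replay is enabled, the buffer holds at most this many transitions
    # (0 with the defaults, since `replay_buffer_num_slots` is 0).
    replay_capacity = (config["replay_buffer_num_slots"] *
                       config["rollout_fragment_length"])
    return rollout_batch, rollouts_per_train_batch, replay_capacity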


def choose_policy(config):
    if config["vtrace"]:
        return VTraceTFPolicy
    else:
        return A3CTFPolicy


def validate_config(config):
    # PyTorch check.
    if config["use_pytorch"]:
        raise ValueError(
            "IMPALA does not support PyTorch yet! Use tf instead.")
    if config["entropy_coeff"] < 0:
        raise ValueError("entropy_coeff must be >= 0")


def defer_make_workers(trainer, env_creator, policy, config):
    # Defer worker creation to after the optimizer has been created.
    return trainer._make_workers(env_creator, policy, config, 0)


def make_aggregators_and_optimizer(workers, config):
    if config["num_aggregation_workers"] > 0:
        # Create co-located aggregator actors first for placement pref
        aggregators = TreeAggregator.precreate_aggregators(
            config["num_aggregation_workers"])
    else:
        aggregators = None
    workers.add_workers(config["num_workers"])

    optimizer = AsyncSamplesOptimizer(
        workers,
        lr=config["lr"],
        num_gpus=config["num_gpus"],
        rollout_fragment_length=config["rollout_fragment_length"],
        train_batch_size=config["train_batch_size"],
        replay_buffer_num_slots=config["replay_buffer_num_slots"],
        replay_proportion=config["replay_proportion"],
        num_data_loader_buffers=config["num_data_loader_buffers"],
        max_sample_requests_in_flight_per_worker=config[
            "max_sample_requests_in_flight_per_worker"],
        broadcast_interval=config["broadcast_interval"],
        num_sgd_iter=config["num_sgd_iter"],
        minibatch_buffer_size=config["minibatch_buffer_size"],
        num_aggregation_workers=config["num_aggregation_workers"],
        learner_queue_size=config["learner_queue_size"],
        learner_queue_timeout=config["learner_queue_timeout"],
        **config["optimizer"])

    if aggregators:
        # Assign the pre-created aggregators to the optimizer
        optimizer.aggregator.init(aggregators)
    return optimizer
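

# A hypothetical config override (illustration only, values are arbitrary):
# multi-level aggregation mainly pays off at large worker counts / high
# sample throughput, as noted in DEFAULT_CONFIG above.
_example_aggregation_overrides = {
    # Scale out sampling...
    "num_workers": 32,
    # ...and fan the sample streams into 4 co-located aggregation actors in
    # front of the learner instead of sending them to it directly.
    "num_aggregation_workers": 4,
}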


class OverrideDefaultResourceRequest:
    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        Trainer._validate_config(cf)
        return Resources(
            cpu=cf["num_cpus_for_driver"],
            gpu=cf["num_gpus"],
            memory=cf["memory"],
            object_store_memory=cf["object_store_memory"],
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"] +
            cf["num_aggregation_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"],
            extra_memory=cf["memory_per_worker"] * cf["num_workers"],
            extra_object_store_memory=cf["object_store_memory_per_worker"] *
            cf["num_workers"])


ImpalaTrainer = build_trainer(
    name="IMPALA",
    default_config=DEFAULT_CONFIG,
    default_policy=VTraceTFPolicy,
    validate_config=validate_config,
    get_policy_class=choose_policy,
    make_workers=defer_make_workers,
    make_policy_optimizer=make_aggregators_and_optimizer,
    mixins=[OverrideDefaultResourceRequest])
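

# Example usage (a minimal sketch; assumes a local Ray cluster and uses the
# illustrative "CartPole-v0" Gym environment, with the GPU requirement
# dropped so it also runs on CPU-only machines):
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = ImpalaTrainer(env="CartPole-v0", config={"num_gpus": 0})
    # Each call runs at least `min_iter_time_s` seconds of async sampling
    # and SGD, then reports metrics such as `episode_reward_mean`.
    result = trainer.train()
    print(result["episode_reward_mean"])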