mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Report timesteps_this_iter to Tune, so it can track/checkpoint/restore total timesteps trained. * Trigger Build * lint
360 lines
15 KiB
Python
360 lines
15 KiB
Python
import logging
|
|
|
|
import ray
|
|
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
|
|
from ray.rllib.agents.trainer import with_common_config
|
|
from ray.rllib.agents.trainer_template import build_trainer
|
|
from ray.rllib.execution.learner_thread import LearnerThread
|
|
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
|
|
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
|
|
from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER,
|
|
STEPS_TRAINED_THIS_ITER_COUNTER,
|
|
_get_global_vars, _get_shared_metrics)
|
|
from ray.rllib.execution.replay_ops import MixInReplay
|
|
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
|
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
|
|
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
|
from ray.tune.trainable import Trainable
|
|
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# yapf: disable
|
|
# __sphinx_doc_begin__
|
|
DEFAULT_CONFIG = with_common_config({
    # ---- V-trace ----
    # V-trace params (see vtrace_tf/torch.py).
    "vtrace": True,
    "vtrace_clip_rho_threshold": 1.0,
    "vtrace_clip_pg_rho_threshold": 1.0,

    # ---- System params ----
    #
    # == Overview of data flow in IMPALA ==
    # 1. Policy evaluation in parallel across `num_workers` actors produces
    #    batches of size `rollout_fragment_length * num_envs_per_worker`.
    # 2. If enabled, the replay buffer stores and produces batches of size
    #    `rollout_fragment_length * num_envs_per_worker`.
    # 3. If enabled, the minibatch ring buffer stores and replays batches
    #    of size `train_batch_size` up to `num_sgd_iter` times per batch.
    # 4. The learner thread executes data-parallel SGD across `num_gpus`
    #    GPUs on batches of size `train_batch_size`.
    #
    "rollout_fragment_length": 50,
    "train_batch_size": 500,
    "min_iter_time_s": 10,
    "num_workers": 2,
    # Number of GPUs the learner should use.
    "num_gpus": 1,
    # Number of slots to reserve, per stack of multi-GPU towers, for
    # parallel data loading. Values >1 allow data to be loaded into the
    # GPUs in parallel; GPU memory usage grows proportionally with the
    # number of stacks.
    # Example:
    # 2 GPUs and `num_multi_gpu_tower_stacks=3`:
    # - One tower stack consists of 2 GPUs, each with a copy of the
    #   model/graph.
    # - Each of the stacks will create 3 slots for batch data on each of
    #   its GPUs, increasing memory requirements on each GPU by 3x.
    # - This enables us to preload data into these stacks while another
    #   stack is performing gradient calculations.
    "num_multi_gpu_tower_stacks": 1,
    # Number of train batches retained for minibatching. Only has an
    # effect if `num_sgd_iter > 1`.
    "minibatch_buffer_size": 1,
    # Number of passes to make over each train batch.
    "num_sgd_iter": 1,
    # Set >0 to enable experience replay. Saved samples will be replayed
    # with a p:1 proportion to new data samples.
    "replay_proportion": 0.0,
    # Number of sample batches to store for replay. The total number of
    # transitions saved is
    # (replay_buffer_num_slots * rollout_fragment_length).
    "replay_buffer_num_slots": 0,
    # Max queue size for train batches feeding into the learner.
    "learner_queue_size": 16,
    # Number of seconds to wait for train batches to become available in
    # the minibatch buffer queue. May need to be increased, e.g. when
    # training with a slow environment.
    "learner_queue_timeout": 300,
    # Level of queuing for sampling.
    "max_sample_requests_in_flight_per_worker": 2,
    # Max number of workers to broadcast one set of weights to.
    "broadcast_interval": 1,
    # Use n (`num_aggregation_workers`) extra Actors for multi-level
    # aggregation of the data produced by the m RolloutWorkers
    # (`num_workers`). Note that n should be much smaller than m.
    # This can make sense if ingesting >2GB/s of samples, or if
    # the data requires decompression.
    "num_aggregation_workers": 0,

    # ---- Learning params ----
    "grad_clip": 40.0,
    # Either "adam" or "rmsprop".
    "opt_type": "adam",
    "lr": 0.0005,
    "lr_schedule": None,
    # Settings used with `opt_type=rmsprop`.
    "decay": 0.99,
    "momentum": 0.0,
    "epsilon": 0.1,
    # Balancing the three losses.
    "vf_loss_coeff": 0.5,
    "entropy_coeff": 0.01,
    "entropy_coeff_schedule": None,
    # Set this to True to have two separate optimizers optimize the
    # policy and value networks.
    "_separate_vf_optimizer": False,
    # If `_separate_vf_optimizer` is True, the separate learning rate
    # for the value network.
    "_lr_vf": 0.0005,

    # Callback for APPO to use to update KL, target network periodically.
    # The input to the callback is the learner fetches dict.
    "after_train_step": None,

    # ---- DEPRECATED ----
    # Superseded by `num_multi_gpu_tower_stacks`.
    "num_data_loader_buffers": DEPRECATED_VALUE,
})
|
|
# __sphinx_doc_end__
|
|
# yapf: enable
|
|
|
|
|
|
class OverrideDefaultResourceRequest:
    """Mixin overriding `default_resource_request` for IMPALA.

    On top of the driver and rollout-worker bundles, the request accounts
    for the CPUs used by aggregation workers and, if evaluation is
    enabled, for the remote evaluation workers.
    """

    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        """Return the resources needed to run a trial with `config`.

        Args:
            config: Partial trainer config dict; it is merged on top of
                `cls._default_config` before any value is looked up.

        Returns:
            A PlacementGroupFactory whose bundles are, in order: one
            driver bundle, one bundle per rollout worker and (only if
            `evaluation_interval` is truthy) one bundle per evaluation
            worker.
        """
        cf = dict(cls._default_config, **config)
        eval_config = cf["evaluation_config"]

        # Driver + Aggregation Workers:
        # Force to be on same node to maximize data bandwidth between
        # aggregation workers and the learner (driver). Aggregation
        # workers tree-aggregate experiences collected from
        # RolloutWorkers (n rollout workers map to m aggregation workers,
        # where m < n) and always use 1 CPU each.
        driver_bundle = {
            "CPU": cf["num_cpus_for_driver"] + cf["num_aggregation_workers"],
            "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
        }

        # RolloutWorkers.
        rollout_bundles = [{
            "CPU": cf["num_cpus_per_worker"],
            "GPU": cf["num_gpus_per_worker"],
        } for _ in range(cf["num_workers"])]

        # Evaluation (remote) workers; values in `evaluation_config`
        # take precedence over the top-level per-worker settings.
        # Note: The local eval worker is located on the driver CPU.
        eval_bundles = []
        if cf["evaluation_interval"]:
            eval_bundles = [{
                "CPU": eval_config.get("num_cpus_per_worker",
                                       cf["num_cpus_per_worker"]),
                "GPU": eval_config.get("num_gpus_per_worker",
                                       cf["num_gpus_per_worker"]),
            } for _ in range(cf["evaluation_num_workers"])]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        # NOTE(review): the strategy is read from the raw `config`, not
        # from the merged `cf` - confirm this is intended.
        return PlacementGroupFactory(
            bundles=[driver_bundle] + rollout_bundles + eval_bundles,
            strategy=config.get("placement_strategy", "PACK"))
|
|
|
|
|
|
def make_learner_thread(local_worker, config):
    """Create the learner thread matching the optimizer mode in `config`.

    Args:
        local_worker: Passed through as first argument to the learner
            thread constructor.
        config: Trainer config dict. In multi-GPU mode,
            `config["minibatch_buffer_size"]` is lowered in place to
            `num_multi_gpu_tower_stacks` if it exceeds that value.

    Returns:
        A LearnerThread if `config["simple_optimizer"]` is truthy,
        otherwise a MultiGPULearnerThread. The thread is not started here.
    """
    if config["simple_optimizer"]:
        return LearnerThread(
            local_worker,
            minibatch_buffer_size=config["minibatch_buffer_size"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"])

    logger.info(
        "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".format(
            config["num_gpus"], config["num_multi_gpu_tower_stacks"]))

    tower_stacks = config["num_multi_gpu_tower_stacks"]
    buffer_slots = config["minibatch_buffer_size"]
    if tower_stacks < buffer_slots:
        # More buffer slots than tower stacks were configured: warn and
        # shrink the buffer to the number of stacks.
        logger.warning(
            "In multi-GPU mode you should have at least as many "
            "multi-GPU tower stacks (to load data into on one device) as "
            "you have stack-index slots in the buffer! You have "
            f"configured {tower_stacks} stacks and a buffer of size "
            f"{buffer_slots}. Setting "
            f"`minibatch_buffer_size={tower_stacks}`.")
        config["minibatch_buffer_size"] = tower_stacks

    return MultiGPULearnerThread(
        local_worker,
        num_gpus=config["num_gpus"],
        lr=config["lr"],
        train_batch_size=config["train_batch_size"],
        num_multi_gpu_tower_stacks=config["num_multi_gpu_tower_stacks"],
        num_sgd_iter=config["num_sgd_iter"],
        learner_queue_size=config["learner_queue_size"],
        learner_queue_timeout=config["learner_queue_timeout"])
|
|
|
|
|
|
def get_policy_class(config):
    """Select the policy class for the given framework / V-trace setting.

    Args:
        config: Trainer config dict; reads `framework` and `vtrace`.

    Returns:
        For `framework == "torch"`: VTraceTorchPolicy if `vtrace` is
        truthy, else A3CTorchPolicy. For any other framework:
        VTraceTFPolicy if `vtrace` is truthy, else A3CTFPolicy.
    """
    # All classes except VTraceTFPolicy are imported lazily, only when
    # they are actually selected.
    if config["framework"] != "torch":
        if config["vtrace"]:
            return VTraceTFPolicy
        from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
        return A3CTFPolicy

    if config["vtrace"]:
        from ray.rllib.agents.impala.vtrace_torch_policy import \
            VTraceTorchPolicy
        return VTraceTorchPolicy

    from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
    return A3CTorchPolicy
|
|
|
|
|
|
def validate_config(config):
    """Check an IMPALA config for consistency and patch it in place.

    Performed steps, in order:
    1. If the deprecated `num_data_loader_buffers` key was set, emit a
       deprecation warning and copy its value into
       `num_multi_gpu_tower_stacks`.
    2. Reject a negative `entropy_coeff`.
    3. Reject more aggregation workers than rollout workers; warn if
       there are more than half as many.
    4. With `_separate_vf_optimizer` enabled, require a tf framework and
       force `_tf_policy_handles_more_than_one_loss` to True.

    Args:
        config: Trainer config dict. May be modified in place (steps 1
            and 4).

    Raises:
        ValueError: If `entropy_coeff` < 0.0, if
            `num_aggregation_workers` > `num_workers`, or if
            `_separate_vf_optimizer` is True for a non-tf framework.
    """
    if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
        deprecation_warning(
            "num_data_loader_buffers",
            "num_multi_gpu_tower_stacks",
            error=False)
        config["num_multi_gpu_tower_stacks"] = \
            config["num_data_loader_buffers"]

    if config["entropy_coeff"] < 0.0:
        raise ValueError("`entropy_coeff` must be >= 0.0!")

    # Check whether worker to aggregation-worker ratio makes sense.
    if config["num_aggregation_workers"] > config["num_workers"]:
        raise ValueError(
            "`num_aggregation_workers` must be smaller than or equal "
            "`num_workers`! Aggregation makes no sense otherwise.")
    elif config["num_aggregation_workers"] > \
            config["num_workers"] / 2:
        # The first fragment ends with a space so that the adjacent
        # string literals do not run together in the logged message.
        logger.warning(
            "`num_aggregation_workers` should be significantly smaller than "
            "`num_workers`! Try setting it to 0.5*`num_workers` or less.")

    # If two separate optimizers/loss terms used for tf, must also set
    # `_tf_policy_handles_more_than_one_loss` to True.
    if config["_separate_vf_optimizer"] is True:
        # Only supported to tf so far.
        # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
        #  models to return separate sets of weights in order to create
        #  the different torch optimizers).
        if config["framework"] not in ["tf", "tf2", "tfe"]:
            raise ValueError(
                "`_separate_vf_optimizer` only supported to tf so far!")
        if config["_tf_policy_handles_more_than_one_loss"] is False:
            logger.warning(
                "`_tf_policy_handles_more_than_one_loss` must be set to True, "
                "for TFPolicy to support more than one loss term/optimizer! "
                "Auto-setting it to True.")
            config["_tf_policy_handles_more_than_one_loss"] = True
|
|
|
|
|
|
# Update worker weights as they finish generating experiences.
|
|
class BroadcastUpdateLearnerWeights:
    """Callable that sends the learner's weights to the actor of each item.

    The weights sent are cached on the instance. They are re-read from
    the local worker (and stored via `ray.put`) only when at least
    `broadcast_interval` items have been processed since the last refresh
    AND the learner thread has flagged `weights_updated`.
    """

    def __init__(self, learner_thread, workers, broadcast_interval):
        """Store the collaborators and take an initial weights snapshot.

        Args:
            learner_thread: Object whose `weights_updated` flag is read
                and reset by `__call__`.
            workers: Object providing `local_worker()`.
            broadcast_interval: Minimum number of processed items between
                two refreshes of the cached weights.
        """
        self.learner_thread = learner_thread
        self.workers = workers
        self.broadcast_interval = broadcast_interval
        self.steps_since_broadcast = 0
        # Initial snapshot: the plain return value of `get_weights()`;
        # later refreshes store the result of `ray.put` instead.
        self.weights = workers.local_worker().get_weights()

    def __call__(self, item):
        """Push the cached weights and global vars to the item's actor.

        Args:
            item: An `(actor, batch)` pair; only the actor is used.
        """
        actor, _ = item
        self.steps_since_broadcast += 1

        interval_reached = \
            self.steps_since_broadcast >= self.broadcast_interval
        if interval_reached and self.learner_thread.weights_updated:
            local_worker = self.workers.local_worker()
            self.weights = ray.put(local_worker.get_weights())
            self.steps_since_broadcast = 0
            self.learner_thread.weights_updated = False
            # Update metrics.
            _get_shared_metrics().counters["num_weight_broadcasts"] += 1

        actor.set_weights.remote(self.weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
|
|
|
|
|
|
def record_steps_trained(item):
    """Update the steps-trained counters from one learner output item.

    Args:
        item: A `(count, fetches)` pair; only `count` is used.

    Returns:
        `item`, unchanged.
    """
    count, _ = item
    # Manually update the steps trained counters since the learner thread
    # is executing outside the pipeline.
    counters = _get_shared_metrics().counters
    # Per-iteration counter is overwritten, the total is accumulated.
    counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
    counters[STEPS_TRAINED_COUNTER] += count
    return item
|
|
|
|
|
|
def gather_experiences_directly(workers, config):
    """Build the train-batch iterator without aggregation workers.

    Args:
        workers: Passed to ParallelRollouts as the source of rollouts.
        config: Trainer config dict; reads
            `max_sample_requests_in_flight_per_worker`,
            `replay_buffer_num_slots`, `replay_proportion`,
            `train_batch_size` and `multiagent.count_steps_by`.

    Returns:
        The iterator resulting from: async rollouts -> decompression ->
        replay mix-in -> flatten -> concatenation to `train_batch_size`.
    """
    rollouts = ParallelRollouts(
        workers,
        mode="async",
        num_async=config["max_sample_requests_in_flight_per_worker"])

    decompressed = rollouts.for_each(
        lambda batch: batch.decompress_if_needed())

    # Augment with replay ...
    replay_mixer = MixInReplay(
        num_slots=config["replay_buffer_num_slots"],
        replay_proportion=config["replay_proportion"])
    mixed = decompressed.for_each(replay_mixer).flatten()

    # ... and concat to desired train batch size.
    concat = ConcatBatches(
        min_batch_size=config["train_batch_size"],
        count_steps_by=config["multiagent"]["count_steps_by"],
    )
    return mixed.combine(concat)
|
|
|
|
|
|
def execution_plan(workers, config):
    """Assemble IMPALA's execution plan and start its learner thread.

    Args:
        workers: Worker set; provides `local_worker()` and
            `remote_workers()`.
        config: Trainer config dict.

    Returns:
        The result of `StandardMetricsReporting(...)` with
        `learner_thread.add_learner_metrics` applied to each item.
    """
    # Pick the experience-gathering strategy.
    if config["num_aggregation_workers"] > 0:
        gather_fn = gather_experiences_tree_aggregation
    else:
        gather_fn = gather_experiences_directly
    train_batches = gather_fn(workers, config)

    # Start the learner thread.
    learner_thread = make_learner_thread(workers.local_worker(), config)
    learner_thread.start()

    # This sub-flow sends experiences to the learner.
    enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        with_actor = enqueue_op.zip_with_source_actor()
        enqueue_op = with_actor.for_each(
            BroadcastUpdateLearnerWeights(
                learner_thread,
                workers,
                broadcast_interval=config["broadcast_interval"]))

    # This sub-flow updates the steps trained counter based on learner
    # output.
    learner_output = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive)
    dequeue_op = learner_output.for_each(record_steps_trained)

    merged_op = Concurrently(
        [enqueue_op, dequeue_op], mode="async", output_indexes=[1])

    # Callback for APPO to use to update KL, target network periodically.
    # The input to the callback is the learner fetches dict.
    callback_factory = config["after_train_step"]
    if callback_factory:
        # Keep only the second element of each item before handing it to
        # the callback.
        fetches_only = merged_op.for_each(lambda t: t[1])
        merged_op = fetches_only.for_each(callback_factory(workers, config))

    reporting = StandardMetricsReporting(merged_op, workers, config)
    return reporting.for_each(learner_thread.add_learner_metrics)
|
|
|
|
|
|
# The IMPALA trainer class, assembled from the module-level pieces above:
# config defaults and validation, policy selection, the execution plan,
# and the resource-request mixin.
ImpalaTrainer = build_trainer(
    name="IMPALA",
    default_config=DEFAULT_CONFIG,
    default_policy=VTraceTFPolicy,
    get_policy_class=get_policy_class,
    validate_config=validate_config,
    execution_plan=execution_plan,
    mixins=[OverrideDefaultResourceRequest],
)
|