[RLlib] Trainer sub-class IMPALA (instead of using build_trainer()). (#20570)

Sven Mika 2021-11-30 19:08:36 +01:00 committed by GitHub
parent e3e2739164
commit bec719d823
2 changed files with 175 additions and 166 deletions
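
From a user's perspective the IMPALA API is unchanged by this refactor: the trainer is still imported and constructed the same way, only its internals move from the build_trainer() factory into an explicit Trainer sub-class. A minimal smoke-test sketch of that unchanged usage (assuming an RLlib version of this era, e.g. Ray 1.9/1.10, with TensorFlow installed for the default framework; the small config values are illustrative):

import ray
from ray.rllib.agents.impala import DEFAULT_CONFIG, ImpalaTrainer

if __name__ == "__main__":
    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 1   # keep the smoke test small
    config["num_gpus"] = 0
    config["framework"] = "tf"  # or "torch"
    trainer = ImpalaTrainer(config=config, env="CartPole-v0")
    result = trainer.train()    # episode_reward_mean may still be nan on iter 1
    print(result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()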

rllib/agents/impala/__init__.py

@@ -1,3 +1,6 @@
-from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.impala.impala import DEFAULT_CONFIG, ImpalaTrainer
 
-__all__ = ["ImpalaTrainer", "DEFAULT_CONFIG"]
+__all__ = [
+    "DEFAULT_CONFIG",
+    "ImpalaTrainer",
+]

rllib/agents/impala/impala.py

@@ -1,9 +1,9 @@
 import logging
+from typing import Optional, Type
 
 import ray
 from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
-from ray.rllib.agents.trainer import with_common_config
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer, with_common_config
 from ray.rllib.execution.learner_thread import LearnerThread
 from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
 from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
@@ -14,9 +14,10 @@ from ray.rllib.execution.replay_ops import MixInReplay
 from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
 from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
-from ray.tune.trainable import Trainable
+from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 
 logger = logging.getLogger(__name__)
@@ -126,47 +127,6 @@ DEFAULT_CONFIG = with_common_config({
 # yapf: enable
 
 
-class OverrideDefaultResourceRequest:
-    @classmethod
-    @override(Trainable)
-    def default_resource_request(cls, config):
-        cf = dict(cls.get_default_config(), **config)
-        eval_config = cf["evaluation_config"]
-        # Return PlacementGroupFactory containing all needed resources
-        # (already properly defined as device bundles).
-        return PlacementGroupFactory(
-            bundles=[{
-                # Driver + Aggregation Workers:
-                # Force to be on same node to maximize data bandwidth
-                # between aggregation workers and the learner (driver).
-                # Aggregation workers tree-aggregate experiences collected
-                # from RolloutWorkers (n rollout workers map to m
-                # aggregation workers, where m < n) and always use 1 CPU
-                # each.
-                "CPU": cf["num_cpus_for_driver"] +
-                cf["num_aggregation_workers"],
-                "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
-            }] + [
-                {
-                    # RolloutWorkers.
-                    "CPU": cf["num_cpus_per_worker"],
-                    "GPU": cf["num_gpus_per_worker"],
-                } for _ in range(cf["num_workers"])
-            ] + ([
-                {
-                    # Evaluation (remote) workers.
-                    # Note: The local eval worker is located on the driver CPU.
-                    "CPU": eval_config.get("num_cpus_per_worker",
-                                           cf["num_cpus_per_worker"]),
-                    "GPU": eval_config.get("num_gpus_per_worker",
-                                           cf["num_gpus_per_worker"]),
-                } for _ in range(cf["evaluation_num_workers"])
-            ] if cf["evaluation_interval"] else []),
-            strategy=config.get("placement_strategy", "PACK"))
-
-
 def make_learner_thread(local_worker, config):
     if not config["simple_optimizer"]:
         logger.info(
@@ -203,63 +163,26 @@ def make_learner_thread(local_worker, config):
     return learner_thread
 
 
-def get_policy_class(config):
-    if config["framework"] == "torch":
-        if config["vtrace"]:
-            from ray.rllib.agents.impala.vtrace_torch_policy import \
-                VTraceTorchPolicy
-            return VTraceTorchPolicy
-        else:
-            from ray.rllib.agents.a3c.a3c_torch_policy import \
-                A3CTorchPolicy
-            return A3CTorchPolicy
-    else:
-        if config["vtrace"]:
-            return VTraceTFPolicy
-        else:
-            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
-            return A3CTFPolicy
-
-
-def validate_config(config):
-    if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
-        deprecation_warning(
-            "num_data_loader_buffers",
-            "num_multi_gpu_tower_stacks",
-            error=False)
-        config["num_multi_gpu_tower_stacks"] = \
-            config["num_data_loader_buffers"]
-
-    if config["entropy_coeff"] < 0.0:
-        raise ValueError("`entropy_coeff` must be >= 0.0!")
-
-    # Check whether worker to aggregation-worker ratio makes sense.
-    if config["num_aggregation_workers"] > config["num_workers"]:
-        raise ValueError(
-            "`num_aggregation_workers` must be smaller than or equal "
-            "`num_workers`! Aggregation makes no sense otherwise.")
-    elif config["num_aggregation_workers"] > \
-            config["num_workers"] / 2:
-        logger.warning(
-            "`num_aggregation_workers` should be significantly smaller than"
-            "`num_workers`! Try setting it to 0.5*`num_workers` or less.")
-
-    # If two separate optimizers/loss terms used for tf, must also set
-    # `_tf_policy_handles_more_than_one_loss` to True.
-    if config["_separate_vf_optimizer"] is True:
-        # Only supported to tf so far.
-        # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the models
-        # to return separate sets of weights in order to create the different
-        # torch optimizers).
-        if config["framework"] not in ["tf", "tf2", "tfe"]:
-            raise ValueError(
-                "`_separate_vf_optimizer` only supported to tf so far!")
-        if config["_tf_policy_handles_more_than_one_loss"] is False:
-            logger.warning(
-                "`_tf_policy_handles_more_than_one_loss` must be set to True, "
-                "for TFPolicy to support more than one loss term/optimizer! "
-                "Auto-setting it to True.")
-            config["_tf_policy_handles_more_than_one_loss"] = True
+def gather_experiences_directly(workers, config):
+    rollouts = ParallelRollouts(
+        workers,
+        mode="async",
+        num_async=config["max_sample_requests_in_flight_per_worker"])
+
+    # Augment with replay and concat to desired train batch size.
+    train_batches = rollouts \
+        .for_each(lambda batch: batch.decompress_if_needed()) \
+        .for_each(MixInReplay(
+            num_slots=config["replay_buffer_num_slots"],
+            replay_proportion=config["replay_proportion"])) \
+        .flatten() \
+        .combine(
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"]["count_steps_by"],
+            ))
+
+    return train_batches
 
 
 # Update worker weights as they finish generating experiences.
@@ -287,44 +210,87 @@ class BroadcastUpdateLearnerWeights:
         self.workers.local_worker().set_global_vars(_get_global_vars())
 
 
-def record_steps_trained(item):
-    count, fetches = item
-    metrics = _get_shared_metrics()
-    # Manually update the steps trained counter since the learner thread
-    # is executing outside the pipeline.
-    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
-    metrics.counters[STEPS_TRAINED_COUNTER] += count
-    return item
-
-
-def gather_experiences_directly(workers, config):
-    rollouts = ParallelRollouts(
-        workers,
-        mode="async",
-        num_async=config["max_sample_requests_in_flight_per_worker"])
-
-    # Augment with replay and concat to desired train batch size.
-    train_batches = rollouts \
-        .for_each(lambda batch: batch.decompress_if_needed()) \
-        .for_each(MixInReplay(
-            num_slots=config["replay_buffer_num_slots"],
-            replay_proportion=config["replay_proportion"])) \
-        .flatten() \
-        .combine(
-            ConcatBatches(
-                min_batch_size=config["train_batch_size"],
-                count_steps_by=config["multiagent"]["count_steps_by"],
-            ))
-
-    return train_batches
+class ImpalaTrainer(Trainer):
+    @classmethod
+    @override(Trainer)
+    def get_default_config(cls) -> TrainerConfigDict:
+        return DEFAULT_CONFIG
+
+    @override(Trainer)
+    def get_default_policy_class(self, config: PartialTrainerConfigDict) -> \
+            Optional[Type[Policy]]:
+        if config["framework"] == "torch":
+            if config["vtrace"]:
+                from ray.rllib.agents.impala.vtrace_torch_policy import \
+                    VTraceTorchPolicy
+                return VTraceTorchPolicy
+            else:
+                from ray.rllib.agents.a3c.a3c_torch_policy import \
+                    A3CTorchPolicy
+                return A3CTorchPolicy
+        else:
+            if config["vtrace"]:
+                return VTraceTFPolicy
+            else:
+                from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
+                return A3CTFPolicy
+
+    @override(Trainer)
+    def validate_config(self, config):
+        # Call the super class' validation method first.
+        super().validate_config(config)
+
+        # Check the IMPALA specific config.
+
+        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
+            deprecation_warning(
+                "num_data_loader_buffers",
+                "num_multi_gpu_tower_stacks",
+                error=False)
+            config["num_multi_gpu_tower_stacks"] = \
+                config["num_data_loader_buffers"]
+
+        if config["entropy_coeff"] < 0.0:
+            raise ValueError("`entropy_coeff` must be >= 0.0!")
+
+        # Check whether worker to aggregation-worker ratio makes sense.
+        if config["num_aggregation_workers"] > config["num_workers"]:
+            raise ValueError(
+                "`num_aggregation_workers` must be smaller than or equal "
+                "`num_workers`! Aggregation makes no sense otherwise.")
+        elif config["num_aggregation_workers"] > \
+                config["num_workers"] / 2:
+            logger.warning(
+                "`num_aggregation_workers` should be significantly smaller "
+                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
+                "less.")
+
+        # If two separate optimizers/loss terms used for tf, must also set
+        # `_tf_policy_handles_more_than_one_loss` to True.
+        if config["_separate_vf_optimizer"] is True:
+            # Only supported to tf so far.
+            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
+            # models to return separate sets of weights in order to create
+            # the different torch optimizers).
+            if config["framework"] not in ["tf", "tf2", "tfe"]:
+                raise ValueError(
+                    "`_separate_vf_optimizer` only supported to tf so far!")
+            if config["_tf_policy_handles_more_than_one_loss"] is False:
+                logger.warning(
+                    "`_tf_policy_handles_more_than_one_loss` must be set to "
+                    "True, for TFPolicy to support more than one loss "
+                    "term/optimizer! Auto-setting it to True.")
+                config["_tf_policy_handles_more_than_one_loss"] = True
+
+    @staticmethod
+    @override(Trainer)
     def execution_plan(workers, config, **kwargs):
         assert len(kwargs) == 0, (
             "IMPALA execution_plan does NOT take any additional parameters")
 
         if config["num_aggregation_workers"] > 0:
-            train_batches = gather_experiences_tree_aggregation(workers, config)
+            train_batches = gather_experiences_tree_aggregation(
+                workers, config)
         else:
             train_batches = gather_experiences_directly(workers, config)
@@ -342,7 +308,17 @@ def execution_plan(workers, config, **kwargs):
                     learner_thread, workers,
                     broadcast_interval=config["broadcast_interval"]))
 
-    # This sub-flow updates the steps trained counter based on learner output.
+        def record_steps_trained(item):
+            count, fetches = item
+            metrics = _get_shared_metrics()
+            # Manually update the steps trained counter since the learner
+            # thread is executing outside the pipeline.
+            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
+            metrics.counters[STEPS_TRAINED_COUNTER] += count
+            return item
+
+        # This sub-flow updates the steps trained counter based on learner
+        # output.
         dequeue_op = Dequeue(
             learner_thread.outqueue, check=learner_thread.is_alive) \
             .for_each(record_steps_trained)
@@ -359,12 +335,42 @@ def execution_plan(workers, config, **kwargs):
         return StandardMetricsReporting(merged_op, workers, config) \
             .for_each(learner_thread.add_learner_metrics)
 
-
-ImpalaTrainer = build_trainer(
-    name="IMPALA",
-    default_config=DEFAULT_CONFIG,
-    default_policy=VTraceTFPolicy,
-    validate_config=validate_config,
-    get_policy_class=get_policy_class,
-    execution_plan=execution_plan,
-    mixins=[OverrideDefaultResourceRequest])
+    @classmethod
+    @override(Trainer)
+    def default_resource_request(cls, config):
+        cf = dict(cls.get_default_config(), **config)
+
+        eval_config = cf["evaluation_config"]
+
+        # Return PlacementGroupFactory containing all needed resources
+        # (already properly defined as device bundles).
+        return PlacementGroupFactory(
+            bundles=[{
+                # Driver + Aggregation Workers:
+                # Force to be on same node to maximize data bandwidth
+                # between aggregation workers and the learner (driver).
+                # Aggregation workers tree-aggregate experiences collected
+                # from RolloutWorkers (n rollout workers map to m
+                # aggregation workers, where m < n) and always use 1 CPU
+                # each.
+                "CPU": cf["num_cpus_for_driver"] +
+                cf["num_aggregation_workers"],
+                "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
+            }] + [
+                {
+                    # RolloutWorkers.
+                    "CPU": cf["num_cpus_per_worker"],
+                    "GPU": cf["num_gpus_per_worker"],
+                } for _ in range(cf["num_workers"])
+            ] + ([
+                {
+                    # Evaluation (remote) workers.
+                    # Note: The local eval worker is located on the driver
+                    # CPU.
+                    "CPU": eval_config.get("num_cpus_per_worker",
+                                           cf["num_cpus_per_worker"]),
+                    "GPU": eval_config.get("num_gpus_per_worker",
+                                           cf["num_gpus_per_worker"]),
+                } for _ in range(cf["evaluation_num_workers"])
+            ] if cf["evaluation_interval"] else []),
+            strategy=config.get("placement_strategy", "PACK"))
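
Because ImpalaTrainer is now a plain Trainer sub-class rather than the return value of build_trainer(), downstream algorithms can customize it by overriding individual hooks instead of re-assembling the trainer from standalone functions. A hedged sketch of that pattern (the MyImpalaVariant class, its config tweaks, and the extra check are illustrative only, not part of this commit):

from ray.rllib.agents.impala import DEFAULT_CONFIG, ImpalaTrainer


class MyImpalaVariant(ImpalaTrainer):
    """Illustrative sub-class: tweaked defaults plus one extra config check."""

    @classmethod
    def get_default_config(cls):
        # Start from IMPALA's defaults and override a few fields.
        config = DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["train_batch_size"] = 1000
        return config

    def validate_config(self, config):
        # Keep all of IMPALA's own checks, then add a custom one.
        super().validate_config(config)
        if config["num_multi_gpu_tower_stacks"] < 1:
            raise ValueError("`num_multi_gpu_tower_stacks` must be >= 1!")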