Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[RLlib] Trainer sub-class IMPALA (instead of using build_trainer()). (#20570)

parent: e3e2739164
commit: bec719d823

2 changed files with 175 additions and 166 deletions
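Summary of the change: instead of assembling the IMPALA trainer through the functional build_trainer() factory, the algorithm now lives in a Trainer sub-class (ImpalaTrainer) that overrides get_default_config, get_default_policy_class, validate_config, execution_plan and default_resource_request. A minimal sketch of that pattern, using a hypothetical MyTrainer/MyTFPolicy and placeholder helpers (not the actual RLlib code in this commit):

    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.agents.trainer_template import build_trainer
    from ray.rllib.utils.annotations import override

    # Old style: a Trainer class is generated from loose functions.
    # MY_DEFAULT_CONFIG, MyTFPolicy and the helper functions are placeholders.
    MyTrainerOld = build_trainer(
        name="MyAlgo",
        default_config=MY_DEFAULT_CONFIG,
        default_policy=MyTFPolicy,
        validate_config=validate_config,
        get_policy_class=get_policy_class,
        execution_plan=execution_plan)

    # New style: the same pieces become overrides on a Trainer sub-class.
    class MyTrainer(Trainer):
        @classmethod
        @override(Trainer)
        def get_default_config(cls):
            return MY_DEFAULT_CONFIG

        @override(Trainer)
        def get_default_policy_class(self, config):
            return MyTFPolicy

        @override(Trainer)
        def validate_config(self, config):
            super().validate_config(config)
            # ... algorithm-specific checks go here ...

        @staticmethod
        @override(Trainer)
        def execution_plan(workers, config, **kwargs):
            # ... build and return the execution-plan ops here ...
            ...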
rllib/agents/impala/__init__.py

@@ -1,3 +1,6 @@
-from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.impala.impala import DEFAULT_CONFIG, ImpalaTrainer
 
-__all__ = ["ImpalaTrainer", "DEFAULT_CONFIG"]
+__all__ = [
+    "DEFAULT_CONFIG",
+    "ImpalaTrainer",
+]
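The __init__.py hunk above only re-orders the re-exports; the package's public surface is unchanged, so existing imports keep working. A quick sanity check (hypothetical snippet, not part of the commit):

    # Both names are still importable from the package after the refactor.
    from ray.rllib.agents.impala import DEFAULT_CONFIG, ImpalaTrainer

    assert isinstance(DEFAULT_CONFIG, dict)
    assert ImpalaTrainer.__name__ == "ImpalaTrainer"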
rllib/agents/impala/impala.py

@@ -1,9 +1,9 @@
 import logging
+from typing import Optional, Type
 
 import ray
 from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
-from ray.rllib.agents.trainer import with_common_config
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer, with_common_config
 from ray.rllib.execution.learner_thread import LearnerThread
 from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
 from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
@@ -14,9 +14,10 @@ from ray.rllib.execution.replay_ops import MixInReplay
 from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
 from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
-from ray.tune.trainable import Trainable
+from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 
 logger = logging.getLogger(__name__)
@@ -126,47 +127,6 @@ DEFAULT_CONFIG = with_common_config({
 # yapf: enable
 
 
-class OverrideDefaultResourceRequest:
-    @classmethod
-    @override(Trainable)
-    def default_resource_request(cls, config):
-        cf = dict(cls.get_default_config(), **config)
-
-        eval_config = cf["evaluation_config"]
-
-        # Return PlacementGroupFactory containing all needed resources
-        # (already properly defined as device bundles).
-        return PlacementGroupFactory(
-            bundles=[{
-                # Driver + Aggregation Workers:
-                # Force to be on same node to maximize data bandwidth
-                # between aggregation workers and the learner (driver).
-                # Aggregation workers tree-aggregate experiences collected
-                # from RolloutWorkers (n rollout workers map to m
-                # aggregation workers, where m < n) and always use 1 CPU
-                # each.
-                "CPU": cf["num_cpus_for_driver"] +
-                cf["num_aggregation_workers"],
-                "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
-            }] + [
-                {
-                    # RolloutWorkers.
-                    "CPU": cf["num_cpus_per_worker"],
-                    "GPU": cf["num_gpus_per_worker"],
-                } for _ in range(cf["num_workers"])
-            ] + ([
-                {
-                    # Evaluation (remote) workers.
-                    # Note: The local eval worker is located on the driver CPU.
-                    "CPU": eval_config.get("num_cpus_per_worker",
-                                           cf["num_cpus_per_worker"]),
-                    "GPU": eval_config.get("num_gpus_per_worker",
-                                           cf["num_gpus_per_worker"]),
-                } for _ in range(cf["evaluation_num_workers"])
-            ] if cf["evaluation_interval"] else []),
-            strategy=config.get("placement_strategy", "PACK"))
-
-
 def make_learner_thread(local_worker, config):
     if not config["simple_optimizer"]:
         logger.info(
@@ -203,63 +163,26 @@ def make_learner_thread(local_worker, config):
     return learner_thread
 
 
-def get_policy_class(config):
-    if config["framework"] == "torch":
-        if config["vtrace"]:
-            from ray.rllib.agents.impala.vtrace_torch_policy import \
-                VTraceTorchPolicy
-            return VTraceTorchPolicy
-        else:
-            from ray.rllib.agents.a3c.a3c_torch_policy import \
-                A3CTorchPolicy
-            return A3CTorchPolicy
-    else:
-        if config["vtrace"]:
-            return VTraceTFPolicy
-        else:
-            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
-            return A3CTFPolicy
-
-
-def validate_config(config):
-    if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
-        deprecation_warning(
-            "num_data_loader_buffers",
-            "num_multi_gpu_tower_stacks",
-            error=False)
-        config["num_multi_gpu_tower_stacks"] = \
-            config["num_data_loader_buffers"]
-
-    if config["entropy_coeff"] < 0.0:
-        raise ValueError("`entropy_coeff` must be >= 0.0!")
-
-    # Check whether worker to aggregation-worker ratio makes sense.
-    if config["num_aggregation_workers"] > config["num_workers"]:
-        raise ValueError(
-            "`num_aggregation_workers` must be smaller than or equal "
-            "`num_workers`! Aggregation makes no sense otherwise.")
-    elif config["num_aggregation_workers"] > \
-            config["num_workers"] / 2:
-        logger.warning(
-            "`num_aggregation_workers` should be significantly smaller than"
-            "`num_workers`! Try setting it to 0.5*`num_workers` or less.")
-
-    # If two separate optimizers/loss terms used for tf, must also set
-    # `_tf_policy_handles_more_than_one_loss` to True.
-    if config["_separate_vf_optimizer"] is True:
-        # Only supported to tf so far.
-        # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the models
-        # to return separate sets of weights in order to create the different
-        # torch optimizers).
-        if config["framework"] not in ["tf", "tf2", "tfe"]:
-            raise ValueError(
-                "`_separate_vf_optimizer` only supported to tf so far!")
-        if config["_tf_policy_handles_more_than_one_loss"] is False:
-            logger.warning(
-                "`_tf_policy_handles_more_than_one_loss` must be set to True, "
-                "for TFPolicy to support more than one loss term/optimizer! "
-                "Auto-setting it to True.")
-            config["_tf_policy_handles_more_than_one_loss"] = True
+def gather_experiences_directly(workers, config):
+    rollouts = ParallelRollouts(
+        workers,
+        mode="async",
+        num_async=config["max_sample_requests_in_flight_per_worker"])
+
+    # Augment with replay and concat to desired train batch size.
+    train_batches = rollouts \
+        .for_each(lambda batch: batch.decompress_if_needed()) \
+        .for_each(MixInReplay(
+            num_slots=config["replay_buffer_num_slots"],
+            replay_proportion=config["replay_proportion"])) \
+        .flatten() \
+        .combine(
+            ConcatBatches(
+                min_batch_size=config["train_batch_size"],
+                count_steps_by=config["multiagent"]["count_steps_by"],
+            ))
+
+    return train_batches
 
 
 # Update worker weights as they finish generating experiences.
@@ -287,84 +210,167 @@ class BroadcastUpdateLearnerWeights:
         self.workers.local_worker().set_global_vars(_get_global_vars())
 
 
-def record_steps_trained(item):
-    count, fetches = item
-    metrics = _get_shared_metrics()
-    # Manually update the steps trained counter since the learner thread
-    # is executing outside the pipeline.
-    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
-    metrics.counters[STEPS_TRAINED_COUNTER] += count
-    return item
-
-
-def gather_experiences_directly(workers, config):
-    rollouts = ParallelRollouts(
-        workers,
-        mode="async",
-        num_async=config["max_sample_requests_in_flight_per_worker"])
-
-    # Augment with replay and concat to desired train batch size.
-    train_batches = rollouts \
-        .for_each(lambda batch: batch.decompress_if_needed()) \
-        .for_each(MixInReplay(
-            num_slots=config["replay_buffer_num_slots"],
-            replay_proportion=config["replay_proportion"])) \
-        .flatten() \
-        .combine(
-            ConcatBatches(
-                min_batch_size=config["train_batch_size"],
-                count_steps_by=config["multiagent"]["count_steps_by"],
-            ))
-
-    return train_batches
-
-
-def execution_plan(workers, config, **kwargs):
-    assert len(kwargs) == 0, (
-        "IMPALA execution_plan does NOT take any additional parameters")
-
-    if config["num_aggregation_workers"] > 0:
-        train_batches = gather_experiences_tree_aggregation(workers, config)
-    else:
-        train_batches = gather_experiences_directly(workers, config)
-
-    # Start the learner thread.
-    learner_thread = make_learner_thread(workers.local_worker(), config)
-    learner_thread.start()
-
-    # This sub-flow sends experiences to the learner.
-    enqueue_op = train_batches \
-        .for_each(Enqueue(learner_thread.inqueue))
-    # Only need to update workers if there are remote workers.
-    if workers.remote_workers():
-        enqueue_op = enqueue_op.zip_with_source_actor() \
-            .for_each(BroadcastUpdateLearnerWeights(
-                learner_thread, workers,
-                broadcast_interval=config["broadcast_interval"]))
-
-    # This sub-flow updates the steps trained counter based on learner output.
-    dequeue_op = Dequeue(
-        learner_thread.outqueue, check=learner_thread.is_alive) \
-        .for_each(record_steps_trained)
-
-    merged_op = Concurrently(
-        [enqueue_op, dequeue_op], mode="async", output_indexes=[1])
-
-    # Callback for APPO to use to update KL, target network periodically.
-    # The input to the callback is the learner fetches dict.
-    if config["after_train_step"]:
-        merged_op = merged_op.for_each(lambda t: t[1]).for_each(
-            config["after_train_step"](workers, config))
-
-    return StandardMetricsReporting(merged_op, workers, config) \
-        .for_each(learner_thread.add_learner_metrics)
-
-
-ImpalaTrainer = build_trainer(
-    name="IMPALA",
-    default_config=DEFAULT_CONFIG,
-    default_policy=VTraceTFPolicy,
-    validate_config=validate_config,
-    get_policy_class=get_policy_class,
-    execution_plan=execution_plan,
-    mixins=[OverrideDefaultResourceRequest])
+class ImpalaTrainer(Trainer):
+    @classmethod
+    @override(Trainer)
+    def get_default_config(cls) -> TrainerConfigDict:
+        return DEFAULT_CONFIG
+
+    @override(Trainer)
+    def get_default_policy_class(self, config: PartialTrainerConfigDict) -> \
+            Optional[Type[Policy]]:
+        if config["framework"] == "torch":
+            if config["vtrace"]:
+                from ray.rllib.agents.impala.vtrace_torch_policy import \
+                    VTraceTorchPolicy
+                return VTraceTorchPolicy
+            else:
+                from ray.rllib.agents.a3c.a3c_torch_policy import \
+                    A3CTorchPolicy
+                return A3CTorchPolicy
+        else:
+            if config["vtrace"]:
+                return VTraceTFPolicy
+            else:
+                from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
+                return A3CTFPolicy
+
+    @override(Trainer)
+    def validate_config(self, config):
+        # Call the super class' validation method first.
+        super().validate_config(config)
+
+        # Check the IMPALA specific config.
+
+        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
+            deprecation_warning(
+                "num_data_loader_buffers",
+                "num_multi_gpu_tower_stacks",
+                error=False)
+            config["num_multi_gpu_tower_stacks"] = \
+                config["num_data_loader_buffers"]
+
+        if config["entropy_coeff"] < 0.0:
+            raise ValueError("`entropy_coeff` must be >= 0.0!")
+
+        # Check whether worker to aggregation-worker ratio makes sense.
+        if config["num_aggregation_workers"] > config["num_workers"]:
+            raise ValueError(
+                "`num_aggregation_workers` must be smaller than or equal "
+                "`num_workers`! Aggregation makes no sense otherwise.")
+        elif config["num_aggregation_workers"] > \
+                config["num_workers"] / 2:
+            logger.warning(
+                "`num_aggregation_workers` should be significantly smaller "
+                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
+                "less.")
+
+        # If two separate optimizers/loss terms used for tf, must also set
+        # `_tf_policy_handles_more_than_one_loss` to True.
+        if config["_separate_vf_optimizer"] is True:
+            # Only supported to tf so far.
+            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
+            # models to return separate sets of weights in order to create
+            # the different torch optimizers).
+            if config["framework"] not in ["tf", "tf2", "tfe"]:
+                raise ValueError(
+                    "`_separate_vf_optimizer` only supported to tf so far!")
+            if config["_tf_policy_handles_more_than_one_loss"] is False:
+                logger.warning(
+                    "`_tf_policy_handles_more_than_one_loss` must be set to "
+                    "True, for TFPolicy to support more than one loss "
+                    "term/optimizer! Auto-setting it to True.")
+                config["_tf_policy_handles_more_than_one_loss"] = True
+
+    @staticmethod
+    @override(Trainer)
+    def execution_plan(workers, config, **kwargs):
+        assert len(kwargs) == 0, (
+            "IMPALA execution_plan does NOT take any additional parameters")
+
+        if config["num_aggregation_workers"] > 0:
+            train_batches = gather_experiences_tree_aggregation(
+                workers, config)
+        else:
+            train_batches = gather_experiences_directly(workers, config)
+
+        # Start the learner thread.
+        learner_thread = make_learner_thread(workers.local_worker(), config)
+        learner_thread.start()
+
+        # This sub-flow sends experiences to the learner.
+        enqueue_op = train_batches \
+            .for_each(Enqueue(learner_thread.inqueue))
+        # Only need to update workers if there are remote workers.
+        if workers.remote_workers():
+            enqueue_op = enqueue_op.zip_with_source_actor() \
+                .for_each(BroadcastUpdateLearnerWeights(
+                    learner_thread, workers,
+                    broadcast_interval=config["broadcast_interval"]))
+
+        def record_steps_trained(item):
+            count, fetches = item
+            metrics = _get_shared_metrics()
+            # Manually update the steps trained counter since the learner
+            # thread is executing outside the pipeline.
+            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
+            metrics.counters[STEPS_TRAINED_COUNTER] += count
+            return item
+
+        # This sub-flow updates the steps trained counter based on learner
+        # output.
+        dequeue_op = Dequeue(
+            learner_thread.outqueue, check=learner_thread.is_alive) \
+            .for_each(record_steps_trained)
+
+        merged_op = Concurrently(
+            [enqueue_op, dequeue_op], mode="async", output_indexes=[1])
+
+        # Callback for APPO to use to update KL, target network periodically.
+        # The input to the callback is the learner fetches dict.
+        if config["after_train_step"]:
+            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
+                config["after_train_step"](workers, config))
+
+        return StandardMetricsReporting(merged_op, workers, config) \
+            .for_each(learner_thread.add_learner_metrics)
+
+    @classmethod
+    @override(Trainer)
+    def default_resource_request(cls, config):
+        cf = dict(cls.get_default_config(), **config)
+
+        eval_config = cf["evaluation_config"]
+
+        # Return PlacementGroupFactory containing all needed resources
+        # (already properly defined as device bundles).
+        return PlacementGroupFactory(
+            bundles=[{
+                # Driver + Aggregation Workers:
+                # Force to be on same node to maximize data bandwidth
+                # between aggregation workers and the learner (driver).
+                # Aggregation workers tree-aggregate experiences collected
+                # from RolloutWorkers (n rollout workers map to m
+                # aggregation workers, where m < n) and always use 1 CPU
+                # each.
+                "CPU": cf["num_cpus_for_driver"] +
+                cf["num_aggregation_workers"],
+                "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
+            }] + [
+                {
+                    # RolloutWorkers.
+                    "CPU": cf["num_cpus_per_worker"],
+                    "GPU": cf["num_gpus_per_worker"],
+                } for _ in range(cf["num_workers"])
+            ] + ([
+                {
+                    # Evaluation (remote) workers.
+                    # Note: The local eval worker is located on the driver
+                    # CPU.
+                    "CPU": eval_config.get("num_cpus_per_worker",
+                                           cf["num_cpus_per_worker"]),
+                    "GPU": eval_config.get("num_gpus_per_worker",
+                                           cf["num_gpus_per_worker"]),
+                } for _ in range(cf["evaluation_num_workers"])
+            ] if cf["evaluation_interval"] else []),
+            strategy=config.get("placement_strategy", "PACK"))
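Construction and use of the trainer are unchanged by this refactor; a hedged usage sketch (environment id and config values are illustrative only, not taken from the commit):

    import ray
    from ray.rllib.agents.impala import ImpalaTrainer

    ray.init()
    trainer = ImpalaTrainer(
        env="CartPole-v1",        # any registered Gym env id
        config={
            "num_workers": 2,     # rollout workers
            "num_gpus": 0,
            "framework": "tf",    # or "torch"; see get_default_policy_class above
        })
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])
    ray.shutdown()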