ray/rllib/agents/a3c/a2c.py

import logging
import math

from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    ComputeGradients,
    AverageGradients,
    ApplyGradients,
    MultiGPUTrainOneStep,
    TrainOneStep,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
    APPLY_GRADS_TIMER,
    COMPUTE_GRADS_TIMER,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
    PartialTrainerConfigDict,
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

A2C_DEFAULT_CONFIG = merge_dicts(
    A3C_CONFIG,
    {
        "rollout_fragment_length": 20,
        "min_time_s_per_reporting": 10,
        "sample_async": False,
        # A2C supports microbatching, in which we accumulate gradients over
        # batches of this size until the train batch size is reached. This
        # allows training with batch sizes much larger than can fit in GPU
        # memory. To enable, set this to a value less than the train batch
        # size.
        "microbatch_size": None,
    },
)
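
# For example (illustrative values, not defaults): "train_batch_size": 1000
# combined with "microbatch_size": 100 accumulates gradients over
# ceil(1000 / 100) = 10 microbatches before each optimizer step.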

class A2CTrainer(A3CTrainer):
    @classmethod
    @override(A3CTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return A2C_DEFAULT_CONFIG

    @override(A3CTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["microbatch_size"]:
            # Train batch size needs to be significantly larger than
            # microbatch_size.
            if config["train_batch_size"] / config["microbatch_size"] < 3:
                logger.warning(
                    "`train_batch_size` should be considerably larger (at least 3x) "
                    "than `microbatch_size` for a microbatching setup to make sense!"
                )
            # Rollout fragment length needs to be less than microbatch_size.
            if config["rollout_fragment_length"] > config["microbatch_size"]:
                logger.warning(
                    "`rollout_fragment_length` should not be larger than "
                    "`microbatch_size` (try setting them to the same value)! "
                    "Otherwise, microbatches of desired size won't be achievable."
                )
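            # A consistent example setup: with `microbatch_size=100`,
            # `train_batch_size` should be at least 300 and
            # `rollout_fragment_length` at most 100.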

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)

        # Create microbatch variables for collecting gradients on microbatches.
        # These gradients will be accumulated on-the-fly and applied at once
        # (once the train batch size has been collected) to the model.
        if (
            self.config["_disable_execution_plan_api"] is True
            and self.config["microbatch_size"]
        ):
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0
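            # From here on, `_microbatches_grads` holds the elementwise sum of
            # per-variable gradients, `_microbatches_counts` the number of
            # sampled steps, and `_num_microbatches` the number of microbatches
            # accumulated since the last apply step.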

    @override(A3CTrainer)
    def training_iteration(self) -> ResultDict:
        # Without microbatching, this is identical to the Trainer's default
        # implementation. The only differences from a default Trainer are the
        # value function loss term and its value computations alongside each
        # action.
        if self.config["microbatch_size"] is None:
            return Trainer.training_iteration(self)

        # In microbatch mode, we want to compute gradients on experience
        # microbatches, average a number of these microbatches, and then
        # apply the averaged gradient in one SGD step. This conserves GPU
        # memory, allowing for extremely large experience batches to be
        # used.
        if self._by_agent_steps:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_agent_steps=self.config["microbatch_size"]
            )
        else:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_env_steps=self.config["microbatch_size"]
            )
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()

        with self._timers[COMPUTE_GRADS_TIMER]:
            grad, info = self.workers.local_worker().compute_gradients(
                train_batch, single_agent=True
            )
            # New microbatch accumulation phase.
            if self._microbatches_grads is None:
                self._microbatches_grads = grad
            # Existing gradients: Accumulate new gradients on top of existing
            # ones.
            else:
                for i, g in enumerate(grad):
                    self._microbatches_grads[i] += g
            self._microbatches_counts += train_batch.count
            self._num_microbatches += 1

        # If `train_batch_size` reached: Accumulate gradients and apply.
        num_microbatches = math.ceil(
            self.config["train_batch_size"] / self.config["microbatch_size"]
        )
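        # `math.ceil` rounds up, so with e.g. train_batch_size=500 and
        # microbatch_size=64 (illustrative values), gradients are applied after
        # 8 accumulated microbatches, i.e. after roughly 512 sampled steps.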

        if self._num_microbatches >= num_microbatches:
            # Update counters.
            self._counters[STEPS_TRAINED_COUNTER] += self._microbatches_counts
            self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = self._microbatches_counts

            # Apply gradients.
            apply_timer = self._timers[APPLY_GRADS_TIMER]
            with apply_timer:
                self.workers.local_worker().apply_gradients(self._microbatches_grads)
                apply_timer.push_units_processed(self._microbatches_counts)

            # Reset microbatch information.
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0

            # Also update global vars of the local worker.
            # Create current global vars.
            global_vars = {
                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            }
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(
                    policies=self.workers.local_worker().get_policies_to_train(),
                    global_vars=global_vars,
                )

        train_results = {DEFAULT_POLICY_ID: info}

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "A2C execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        if config["microbatch_size"]:
            num_microbatches = math.ceil(
                config["train_batch_size"] / config["microbatch_size"]
            )
            # In microbatch mode, we want to compute gradients on experience
            # microbatches, average a number of these microbatches, and then
            # apply the averaged gradient in one SGD step. This conserves GPU
            # memory, allowing for extremely large experience batches to be
            # used.
            train_op = (
                rollouts.combine(
                    ConcatBatches(
                        min_batch_size=config["microbatch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
                .for_each(ComputeGradients(workers))  # (grads, info)
                .batch(num_microbatches)  # List[(grads, info)]
                .for_each(AverageGradients())  # (avg_grads, info)
                .for_each(ApplyGradients(workers))
            )
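            # Each applied update thus averages gradients from
            # `num_microbatches * microbatch_size` sampled steps, i.e. at
            # least `train_batch_size` steps.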
        else:
            # In normal mode, we execute one SGD step per train batch.
            if config["simple_optimizer"]:
                train_step_op = TrainOneStep(workers)
            else:
                train_step_op = MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config["train_batch_size"],
                    num_sgd_iter=1,
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )

            train_op = rollouts.combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            ).for_each(train_step_op)

        return StandardMetricsReporting(train_op, workers, config)
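
# A minimal usage sketch (assumes a local Ray installation and Gym's
# "CartPole-v0" environment; values are illustrative, not tuned):
if __name__ == "__main__":
    import ray
    from ray import tune

    ray.init()
    tune.run(
        A2CTrainer,
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
            # Sample 50-step fragments, accumulate gradients over 100-step
            # microbatches, and apply them once 1000 steps have been trained.
            "rollout_fragment_length": 50,
            "microbatch_size": 100,
            "train_batch_size": 1000,
        },
        stop={"timesteps_total": 100000},
    )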