import logging
import math

from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    ComputeGradients,
    AverageGradients,
    ApplyGradients,
    MultiGPUTrainOneStep,
    TrainOneStep,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
    APPLY_GRADS_TIMER,
    COMPUTE_GRADS_TIMER,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
    PartialTrainerConfigDict,
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

A2C_DEFAULT_CONFIG = merge_dicts(
    A3C_CONFIG,
    {
        "rollout_fragment_length": 20,
        "min_time_s_per_reporting": 10,
        "sample_async": False,

        # A2C supports microbatching, in which we accumulate gradients over
        # batches of this size until the train batch size is reached. This
        # allows training with batch sizes much larger than can fit in GPU
        # memory. To enable, set this to a value less than the train batch
        # size.
        "microbatch_size": None,
    },
)
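
# For example, with `train_batch_size=1000` and `microbatch_size=100`, gradients
# from ceil(1000 / 100) = 10 sampled microbatches are accumulated before a single
# `apply_gradients()` call updates the model (see `training_iteration` below).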


class A2CTrainer(A3CTrainer):
    @classmethod
    @override(A3CTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return A2C_DEFAULT_CONFIG

    @override(A3CTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["microbatch_size"]:
            # Train batch size needs to be significantly larger than microbatch_size.
            if config["train_batch_size"] / config["microbatch_size"] < 3:
                logger.warning(
                    "`train_batch_size` should be considerably larger (at least 3x) "
                    "than `microbatch_size` for a microbatching setup to make sense!"
                )
            # Rollout fragment length needs to be less than microbatch_size.
            if config["rollout_fragment_length"] > config["microbatch_size"]:
                logger.warning(
                    "`rollout_fragment_length` should not be larger than "
                    "`microbatch_size` (try setting them to the same value)! "
                    "Otherwise, microbatches of desired size won't be achievable."
                )

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)

        # Create microbatch variables for collecting gradients on microbatches.
        # These gradients will be accumulated on-the-fly and applied at once
        # (once the train batch size has been collected) to the model.
        if (
            self.config["_disable_execution_plan_api"] is True
            and self.config["microbatch_size"]
        ):
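            # `_microbatches_grads` holds the running sum of per-variable gradients,
            # `_microbatches_counts` the number of sampled steps, and
            # `_num_microbatches` the number of gradient computations since the
            # last `apply_gradients()` call.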
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0

    @override(A3CTrainer)
    def training_iteration(self) -> ResultDict:
        # W/o microbatching: Identical to Trainer's default implementation.
        # The only difference from a default Trainer is the value function
        # loss term and its value computations alongside each action.
        if self.config["microbatch_size"] is None:
            return Trainer.training_iteration(self)

        # In microbatch mode, we want to compute gradients on experience
        # microbatches, average a number of these microbatches, and then
        # apply the averaged gradient in one SGD step. This conserves GPU
        # memory, allowing for extremely large experience batches to be
        # used.
        if self._by_agent_steps:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_agent_steps=self.config["microbatch_size"]
            )
        else:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_env_steps=self.config["microbatch_size"]
            )
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()

        with self._timers[COMPUTE_GRADS_TIMER]:
            grad, info = self.workers.local_worker().compute_gradients(
                train_batch, single_agent=True
            )
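            # `grad` is a list of gradient tensors (one per trainable variable);
            # `info` carries the learner stats for this microbatch.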
            # New microbatch accumulation phase.
            if self._microbatches_grads is None:
                self._microbatches_grads = grad
            # Existing gradients: Accumulate new gradients on top of existing ones.
            else:
                for i, g in enumerate(grad):
                    self._microbatches_grads[i] += g
            self._microbatches_counts += train_batch.count
            self._num_microbatches += 1

        # If `train_batch_size` reached: Accumulate gradients and apply.
        num_microbatches = math.ceil(
            self.config["train_batch_size"] / self.config["microbatch_size"]
        )
        if self._num_microbatches >= num_microbatches:
            # Update counters.
            self._counters[STEPS_TRAINED_COUNTER] += self._microbatches_counts
            self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = self._microbatches_counts

            # Apply gradients.
            apply_timer = self._timers[APPLY_GRADS_TIMER]
            with apply_timer:
                self.workers.local_worker().apply_gradients(self._microbatches_grads)
                apply_timer.push_units_processed(self._microbatches_counts)

            # Reset microbatch information.
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0

            # Synchronize the local worker's weights and current global vars
            # to all remote workers.
            global_vars = {
                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            }
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(
                    policies=self.workers.local_worker().get_policies_to_train(),
                    global_vars=global_vars,
                )

        train_results = {DEFAULT_POLICY_ID: info}

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "A2C execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        if config["microbatch_size"]:
            num_microbatches = math.ceil(
                config["train_batch_size"] / config["microbatch_size"]
            )
            # In microbatch mode, we want to compute gradients on experience
            # microbatches, average a number of these microbatches, and then
            # apply the averaged gradient in one SGD step. This conserves GPU
            # memory, allowing for extremely large experience batches to be
            # used.
            train_op = (
                rollouts.combine(
                    ConcatBatches(
                        min_batch_size=config["microbatch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
                .for_each(ComputeGradients(workers))  # (grads, info)
                .batch(num_microbatches)  # List[(grads, info)]
                .for_each(AverageGradients())  # (avg_grads, info)
                .for_each(ApplyGradients(workers))
            )
        else:
            # In normal mode, we execute one SGD step for each train batch.
            if config["simple_optimizer"]:
                train_step_op = TrainOneStep(workers)
            else:
                train_step_op = MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config["train_batch_size"],
                    num_sgd_iter=1,
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )

            train_op = rollouts.combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            ).for_each(train_step_op)

        return StandardMetricsReporting(train_op, workers, config)
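

# Example usage (a minimal sketch, assuming Ray/RLlib are installed and the Gym
# env "CartPole-v0" is available; the stop criteria, worker count, and batch
# sizes below are illustrative only):
#
#     import ray
#     from ray import tune
#
#     ray.init()
#     tune.run(
#         "A2C",
#         stop={"timesteps_total": 100000},
#         config={
#             "env": "CartPole-v0",
#             "num_workers": 2,
#             "train_batch_size": 1000,
#             # Accumulate gradients over 100-step microbatches (10 per update).
#             "microbatch_size": 100,
#         },
#     )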