ray/rllib/agents/a3c/a2c.py

234 lines
8.6 KiB
Python

import logging
import math
from typing import Optional
from ray.rllib.agents.a3c.a3c import A3CConfig, A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
APPLY_GRADS_TIMER,
COMPUTE_GRADS_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)
# Module-scoped logger, named after this module, used for config warnings.
logger = logging.getLogger(name=__name__)
class A2CConfig(A3CConfig):
    """Defines an A2CTrainer configuration class from which a new Trainer can be built.

    Example:
        >>> from ray import tune
        >>> config = (
        ...     A2CConfig()
        ...     .training(lr=0.01, grad_clip=30.0)
        ...     .resources(num_gpus=0)
        ...     .rollouts(num_rollout_workers=2)
        ... )
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray import tune
        >>> config = A2CConfig()
        >>> # Print out some default values.
        >>> print(config.sample_async)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), use_critic=False)
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "A2C",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self):
        """Initializes an A2CConfig instance."""
        super().__init__(trainer_class=A2CTrainer)

        # fmt: off
        # __sphinx_doc_begin__

        # A2C specific settings:
        # Gradient microbatch size; None disables microbatching.
        self.microbatch_size = None

        # Override some of A3CConfig's default values with A2C-specific values.
        self.rollout_fragment_length = 20
        self.sample_async = False
        self.min_time_s_per_reporting = 10

        # __sphinx_doc_end__
        # fmt: on

    @override(A3CConfig)
    def training(
        self,
        *,
        microbatch_size: Optional[int] = None,
        **kwargs,
    ) -> "A2CConfig":
        """Sets the training related configuration.

        Args:
            microbatch_size: A2C supports microbatching, in which we accumulate
                gradients over batches of this size until the train batch size
                is reached. This allows training with batch sizes much larger
                than can fit in GPU memory. To enable, set this to a value less
                than the train batch size. If None (the default), the current
                setting is left unchanged.
            **kwargs: All other training settings; forwarded unchanged to
                `A3CConfig.training()`.

        Returns:
            This updated TrainerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        # None means "not provided" here, so only overwrite on an explicit value.
        if microbatch_size is not None:
            self.microbatch_size = microbatch_size

        return self
class A2CTrainer(A3CTrainer):
    """A2C trainer: A3CTrainer with optional synchronous gradient microbatching.

    When `microbatch_size` is set (truthy), each `training_iteration()` call
    samples one microbatch, computes its gradients on the local worker and adds
    them to a running sum. That sum is applied in a single step once enough
    microbatches to cover `train_batch_size` have been processed.
    """

    @classmethod
    @override(A3CTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Returns A2CConfig's defaults as an old-style config dict."""
        return A2CConfig().to_dict()

    @override(A3CTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Validates `config`, warning about ill-suited microbatch settings.

        Args:
            config: The trainer config dict to validate.
        """
        # Call super's validation method.
        super().validate_config(config)

        # Only check microbatch settings if microbatching is enabled; a falsy
        # `microbatch_size` (None or 0) disables it.
        if config["microbatch_size"]:
            # Train batch size needs to be significantly larger than microbatch_size.
            if config["train_batch_size"] / config["microbatch_size"] < 3:
                logger.warning(
                    "`train_batch_size` should be considerably larger (at least 3x) "
                    "than `microbatch_size` for a microbatching setup to make sense!"
                )
            # Rollout fragment length should not exceed microbatch_size.
            if config["rollout_fragment_length"] > config["microbatch_size"]:
                logger.warning(
                    "`rollout_fragment_length` should not be larger than "
                    "`microbatch_size` (try setting them to the same value)! "
                    "Otherwise, microbatches of desired size won't be achievable."
                )

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Sets up the trainer and, if microbatching, its gradient accumulators.

        Args:
            config: The (partial) trainer config dict.
        """
        super().setup(config)

        # Create the state for collecting gradients on microbatches. These
        # gradients are summed on-the-fly and applied at once (once the train
        # batch size has been collected) to the model.
        if (
            self.config["_disable_execution_plan_api"] is True
            and self.config["microbatch_size"]
        ):
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0

    @override(A3CTrainer)
    def training_iteration(self) -> ResultDict:
        """Runs one training iteration.

        Without microbatching this defers to `Trainer.training_iteration`.
        With microbatching it samples one microbatch, adds its gradients to the
        running sum, and applies the sum once `train_batch_size` is covered.
        Weights and global vars are synced to the workers on every call.

        Returns:
            A dict mapping DEFAULT_POLICY_ID to the info returned by the local
            worker's `compute_gradients()` for the latest microbatch.
        """
        # W/o microbatching: Identical to Trainer's default implementation.
        # Only difference to a default Trainer being the value function loss term
        # and its value computations alongside each action.
        # Use truthiness (not `is None`) to match `validate_config()` and
        # `setup()`. Otherwise a `microbatch_size` of 0 would enter the
        # microbatch path without initialized accumulators and divide by zero
        # below.
        if not self.config["microbatch_size"]:
            return Trainer.training_iteration(self)

        microbatch_size = self.config["microbatch_size"]

        # In microbatch mode, we compute gradients on experience microbatches,
        # sum the gradients of a number of these microbatches (note: they are
        # summed, not divided by the microbatch count), and then apply the sum
        # in one SGD step. This conserves GPU memory, allowing for extremely
        # large experience batches to be used.
        if self._by_agent_steps:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_agent_steps=microbatch_size
            )
        else:
            train_batch = synchronous_parallel_sample(
                worker_set=self.workers, max_env_steps=microbatch_size
            )
        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()

        with self._timers[COMPUTE_GRADS_TIMER]:
            grad, info = self.workers.local_worker().compute_gradients(
                train_batch, single_agent=True
            )

        # New microbatch accumulation phase.
        if self._microbatches_grads is None:
            self._microbatches_grads = grad
        # Existing gradients: Accumulate new gradients on top of existing ones.
        else:
            for i, g in enumerate(grad):
                self._microbatches_grads[i] += g
        self._microbatches_counts += train_batch.count
        self._num_microbatches += 1

        # If `train_batch_size` reached: apply the accumulated gradients.
        num_microbatches = math.ceil(
            self.config["train_batch_size"] / microbatch_size
        )
        if self._num_microbatches >= num_microbatches:
            # Update counters.
            self._counters[STEPS_TRAINED_COUNTER] += self._microbatches_counts
            self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = self._microbatches_counts

            # Apply gradients.
            apply_timer = self._timers[APPLY_GRADS_TIMER]
            with apply_timer:
                self.workers.local_worker().apply_gradients(self._microbatches_grads)
                apply_timer.push_units_processed(self._microbatches_counts)

            # Reset microbatch information.
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0

        # Also update global vars of the local worker.
        # Create current global vars.
        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        with self._timers[WORKER_UPDATE_TIMER]:
            self.workers.sync_weights(
                policies=self.workers.local_worker().get_policies_to_train(),
                global_vars=global_vars,
            )

        train_results = {DEFAULT_POLICY_ID: info}

        return train_results
# Deprecated: Use ray.rllib.agents.a3c.A2CConfig instead!
class _deprecated_default_config(dict):
    """Plain dict snapshot of A2CConfig's defaults, kept for backward compatibility.

    Item access goes through the `Deprecated` decorator (with `error=False`,
    presumably a warning rather than an exception). New code should build an
    `A2CConfig` instead of reading `A2C_DEFAULT_CONFIG`.
    """

    def __init__(self):
        # Fill this dict with the old-style defaults of a fresh A2CConfig.
        defaults = A2CConfig().to_dict()
        dict.__init__(self, defaults)

    @Deprecated(
        old="ray.rllib.agents.a3c.a2c.A2C_DEFAULT_CONFIG",
        new="ray.rllib.agents.a3c.a2c.A2CConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        # Ordinary dict lookup; the decorator handles the deprecation notice.
        return dict.__getitem__(self, item)


A2C_DEFAULT_CONFIG = _deprecated_default_config()