ray/rllib/agents/a3c/a3c.py

187 lines
7.4 KiB
Python

import logging
from typing import Any, Dict, Type
from ray.actor import ActorHandle
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
APPLY_GRADS_TIMER,
GRAD_WAIT_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)

# A3C's default config: the common Trainer config (via `with_common_config`)
# updated with the A3C-specific settings below.
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Should use a critic as a baseline (otherwise don't use value baseline;
    # required for using GAE).
    "use_critic": True,
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # Size of rollout batch
    "rollout_fragment_length": 10,
    # GAE(lambda) parameter
    "lambda": 1.0,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Learning rate schedule
    "lr_schedule": None,
    # Value Function Loss coefficient
    "vf_loss_coeff": 0.5,
    # Entropy coefficient (must be >= 0.0).
    "entropy_coeff": 0.01,
    # Entropy coefficient schedule
    "entropy_coeff_schedule": None,
    # Min time (in seconds) per reporting.
    # This means not every call to `training_iteration` is reported on its
    # own: results are accumulated until at least this many seconds have
    # passed and are then summarized.
    "min_time_s_per_reporting": 5,
    # Workers sample asynchronously (requires `num_workers` >= 1). Note that
    # this increases the effective rollout_fragment_length by up to 5x due to
    # async buffering of batches.
    "sample_async": True,
    # Use the Trainer's `training_iteration` function instead of `execution_plan`.
    # Fixes a severe performance problem with A3C. Setting this to True leads to a
    # speedup of up to 3x for a large number of workers and heavier
    # gradient computations (e.g. ray/rllib/tuned_examples/a3c/pong-a3c.yaml).
    "_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on
class A3CTrainer(Trainer):
    """Asynchronous Advantage Actor-Critic (A3C) trainer.

    Remote rollout workers sample and compute gradients in parallel. Each
    returned gradient is applied to the local worker's model as it arrives,
    after which the updated weights are sent back only to the worker that
    produced that gradient.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the A3C default config (common config + A3C overrides)."""
        return DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Validate A3C-specific settings on top of the base Trainer checks.

        Args:
            config: The trainer config to validate.

        Raises:
            ValueError: If `entropy_coeff` is negative, or if `sample_async`
                is enabled while `num_workers` is <= 0.
        """
        # Call super's validation method.
        super().validate_config(config)

        if config["entropy_coeff"] < 0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")
        # `sample_async` requires at least one remote worker.
        if config["num_workers"] <= 0 and config["sample_async"]:
            raise ValueError("`num_workers` for A3C must be >= 1!")

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return the torch policy if `framework` is "torch", else the TF one."""
        if config["framework"] == "torch":
            # Imported lazily, presumably so torch is only needed when selected.
            from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy

            return A3CTorchPolicy
        else:
            return A3CTFPolicy

    def training_iteration(self) -> ResultDict:
        """Run one A3C training iteration.

        With remote workers, this sends (at most one in-flight per worker)
        sample-and-compute-gradients requests. It collects whatever results
        are ready (zero `ray.wait` timeout) and applies each returned gradient
        to the local worker's model. It then syncs the updated weights back to
        the worker that produced them.

        With `num_workers` <= 0 (only allowed when `sample_async` is False),
        sampling and learning happen synchronously on the local worker.
        Otherwise nothing would ever be sampled or trained.

        Returns:
            The finalized learner info built from all gradient results
            applied during this iteration.
        """
        # Shortcut.
        local_worker = self.workers.local_worker()

        # Define the function executed by RolloutWorkers to collect samples +
        # compute and return gradients (and other information).
        def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
            """Call sample() and compute_gradients() on the given worker."""
            samples = worker.sample()
            grads, infos = worker.compute_gradients(samples)
            return {
                "grads": grads,
                "infos": infos,
                "agent_steps": samples.agent_steps(),
                "env_steps": samples.env_steps(),
            }

        learner_info_builder = LearnerInfoBuilder(num_devices=1)

        # No remote workers configured: the async path below would have no
        # actors to send requests to and would never train. Sample and learn
        # on the local worker instead.
        if self.config["num_workers"] <= 0:
            with self._timers[GRAD_WAIT_TIMER]:
                result = sample_and_compute_grads(local_worker)
            self._apply_grads_result(local_worker, result, learner_info_builder)
            # The local worker already holds the updated weights; only its
            # global vars need refreshing.
            local_worker.set_global_vars(
                {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
            )
            return learner_info_builder.finalize()

        # Perform rollouts and gradient calculations asynchronously.
        with self._timers[GRAD_WAIT_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to their
            # returned gradient calculation results.
            async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=self.workers.remote_workers(),
                ray_wait_timeout_s=0.0,
                max_remote_requests_in_flight_per_actor=1,
                remote_fn=sample_and_compute_grads,
            )

        # Loop through all fetched worker-computed gradients (if any)
        # and apply them - one by one - to the local worker's model.
        # After applying all results from one worker, update that particular
        # worker's weights.
        global_vars = None
        for worker, results in async_results.items():
            for result in results:
                self._apply_grads_result(local_worker, result, learner_info_builder)

            # Create current global vars.
            global_vars = {
                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            }

            # Synch updated weights back to the particular worker.
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                weights = local_worker.get_weights(local_worker.get_policies_to_train())
                worker.set_weights.remote(weights, global_vars)

        # Update global vars of the local worker (only if anything was applied).
        if global_vars:
            local_worker.set_global_vars(global_vars)

        return learner_info_builder.finalize()

    def _apply_grads_result(
        self,
        local_worker: RolloutWorker,
        result: Dict[str, Any],
        learner_info_builder: LearnerInfoBuilder,
    ) -> None:
        """Apply one sample-and-compute-gradients result to the local worker.

        Applies the gradients, advances the sampled/trained step counters,
        and records the learner infos.

        Args:
            local_worker: The worker whose model receives the gradients.
            result: Dict with "grads", "infos", "agent_steps" and "env_steps".
            learner_info_builder: Collects the per-result learner infos.
        """
        # Apply gradients to local worker.
        with self._timers[APPLY_GRADS_TIMER]:
            local_worker.apply_gradients(result["grads"])
        self._timers[APPLY_GRADS_TIMER].push_units_processed(result["agent_steps"])

        # Update all step counters. Gradients are computed on exactly the
        # batch that was sampled, so sampled and trained counts advance together.
        self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
        self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
        self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
        self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]

        learner_info_builder.add_learn_on_batch_results_multi_agent(result["infos"])

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        """Legacy execution plan, used when `_disable_execution_plan_api` is False.

        Args:
            workers: The WorkerSet to compute gradients on.
            config: The trainer config.

        Returns:
            An iterator yielding standard training metrics.
        """
        assert (
            len(kwargs) == 0
        ), "A3C execution_plan does NOT take any additional parameters"

        # For A3C, compute policy gradients remotely on the rollout workers.
        grads = AsyncGradients(workers)

        # Apply the gradients as they arrive. We set update_all to False so
        # that only the worker sending the gradient is updated with new
        # weights.
        train_op = grads.for_each(ApplyGradients(workers, update_all=False))

        return StandardMetricsReporting(train_op, workers, config)