ray/rllib/agents/a3c/a3c.py
gjoliver 99a0088233
[RLlib] Unify the way we create local replay buffer for all agents (#19627)
* [RLlib] Unify the way we create and use LocalReplayBuffer for all the agents.

This change:
1. Gets rid of the try...except clause around the execution_plan() call,
   and gets rid of the deprecation warning as a result.
2. Fixes the execution_plan() call in Trainer._try_recover() too.
3. Most importantly, makes it much easier to create and use different types
   of local replay buffers for all our agents.
   E.g., allow us to easily create a reservoir sampling replay buffer for
   APPO agent for Riot in the near future.
* Introduce explicit configuration for replay buffer types.
* Fix is_training key error.
* Actually deprecate the `buffer_size` field.
2021-10-26 20:56:02 +02:00

113 lines
3.9 KiB
Python

import logging
from typing import Optional, Type
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.util.iter import LocalIterator
from ray.rllib.policy.policy import Policy
logger = logging.getLogger(__name__)

# Default configuration for the A3C trainer: the A3C-specific settings below,
# presumably layered on top of the common Trainer defaults by
# `with_common_config()`.
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Should use a critic as a baseline (otherwise don't use value baseline;
    # required for using GAE).
    "use_critic": True,
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # Size of rollout batch
    "rollout_fragment_length": 10,
    # GAE(lambda) parameter
    "lambda": 1.0,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Learning rate schedule
    "lr_schedule": None,
    # Value Function Loss coefficient
    "vf_loss_coeff": 0.5,
    # Entropy coefficient (must be >= 0; checked in validate_config()).
    "entropy_coeff": 0.01,
    # Entropy coefficient schedule
    "entropy_coeff_schedule": None,
    # Min time per iteration
    "min_iter_time_s": 5,
    # Workers sample async. Note that this increases the effective
    # rollout_fragment_length by up to 5x due to async buffering of batches.
    # Requires `num_workers` >= 1 (checked in validate_config()).
    "sample_async": True,
})
# __sphinx_doc_end__
# yapf: enable
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Returns the A3C policy class matching `config["framework"]`.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: A3CTorchPolicy if the framework is "torch",
            otherwise A3CTFPolicy. (Per the build_trainer() contract, a None
            return would mean: use the `default_policy` given there.)
    """
    framework = config["framework"]
    if framework != "torch":
        return A3CTFPolicy
    # Imported only on the torch path (presumably so that torch is not
    # required when running with TF).
    from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
    return A3CTorchPolicy
def validate_config(config: TrainerConfigDict) -> None:
    """Validates the A3C-specific settings in `config`.

    The config is only checked, never modified.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Raises:
        ValueError: If `entropy_coeff` is negative, or if `sample_async` is
            enabled while `num_workers` is <= 0.
    """
    if config["entropy_coeff"] < 0:
        raise ValueError("`entropy_coeff` must be >= 0.0!")

    # `sample_async` is only accepted together with at least one worker.
    # `sample_async` is read only when there are no workers (short-circuit).
    lacks_workers = config["num_workers"] <= 0
    if lacks_workers and config["sample_async"]:
        raise ValueError("`num_workers` for A3C must be >= 1!")
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the A3C algorithm. Defines the distributed dataflow.

    Policy gradients are computed remotely on the rollout workers and applied
    as they arrive; only the worker that sent a gradient gets new weights.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.
        **kwargs: Must be empty; A3C takes no additional parameters.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.

    Raises:
        TypeError: If any additional keyword arguments are passed.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would then let unexpected arguments be silently ignored.
    if kwargs:
        raise TypeError(
            "A3C execution_plan does NOT take any additional parameters, "
            "got: {}".format(sorted(kwargs)))

    # For A3C, compute policy gradients remotely on the rollout workers.
    grads = AsyncGradients(workers)

    # Apply the gradients as they arrive. We set update_all to False so that
    # only the worker sending the gradient is updated with new weights.
    train_op = grads.for_each(ApplyGradients(workers, update_all=False))

    return StandardMetricsReporting(train_op, workers, config)
# Assemble the A3C Trainer class from this module's default config, config
# validator, policy-class picker and execution plan.
A3CTrainer = build_trainer(
    name="A3C",
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    # Fallback policy; get_policy_class() picks the framework-specific one.
    default_policy=A3CTFPolicy,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan,
)