from typing import Optional, Type

from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Input settings ===
    # You should override this to point to an offline dataset
    # (see trainer.py).
    # The dataset may have an arbitrary number of timesteps
    # (and even episodes) per line.
    # However, each line must only contain consecutive timesteps in
    # order for MARWIL to be able to calculate accumulated
    # discounted returns. It is ok, though, to have multiple episodes in
    # the same line.
    "input": "sampler",
    # Use importance sampling estimators for reward.
    "input_evaluation": ["is", "wis"],

    # === Postprocessing/accum., discounted return calculation ===
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf in
    # case an input line ends with a non-terminal timestep.
    "use_gae": True,
    # Whether to calculate cumulative rewards. Must be True.
    "postprocess_inputs": True,

    # === Training ===
    # Scaling of advantages in the exponential weighting term.
    # When beta is 0.0, MARWIL is reduced to behavior cloning
    # (imitation learning); see bc.py algorithm in this same directory.
    "beta": 1.0,
    # Balances the value estimation loss against the policy
    # optimization loss.
    "vf_coeff": 1.0,
    # If specified, clip the global norm of gradients by this amount.
    "grad_clip": None,
    # Learning rate for Adam optimizer.
    "lr": 1e-4,
    # The squared moving avg. advantage norm (c^2) update rate
    # (1e-8 in the paper).
    "moving_average_sqd_adv_norm_update_rate": 1e-8,
    # Starting value for the squared moving avg. advantage norm (c^2).
    "moving_average_sqd_adv_norm_start": 100.0,
    # Number of (independent) timesteps pushed through the loss
    # each SGD round.
    "train_batch_size": 2000,
    # Size of the replay buffer in (single and independent) timesteps.
    # The buffer gets filled by reading from the input files line-by-line
    # and adding all timesteps on one line at once. We then sample
    # uniformly from the buffer (`train_batch_size` samples) for
    # each training step.
    "replay_buffer_size": 10000,
    # Number of steps to read before learning starts.
    "learning_starts": 0,

    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
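

# Added illustration (not part of the original module): the two
# `moving_average_sqd_adv_norm_*` keys above parameterize the advantage
# normalizer c used in MARWIL's exponential advantage weighting,
# exp(beta * advantage / c). A minimal pure-Python sketch of the c^2
# update rule under that reading; `sqd_norm`, `advantages`, and `rate`
# are hypothetical local names:
def _sqd_adv_norm_update_sketch(sqd_norm, advantages, rate=1e-8):
    """Sketch: c^2 <- c^2 + rate * (mean(A^2) - c^2)."""
    # A slowly moving average of the squared advantages seen so far.
    mean_sqd_adv = sum(a * a for a in advantages) / len(advantages)
    return sqd_norm + rate * (mean_sqd_adv - sqd_norm)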


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Policy class picker function. Class is chosen based on DL-framework.

    MARWIL/BC have both TF and Torch policy support.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with MARWILTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":
        from ray.rllib.agents.marwil.marwil_torch_policy import \
            MARWILTorchPolicy
        return MARWILTorchPolicy


def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the MARWIL/BC algorithm. Defines the distributed
    dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert len(kwargs) == 0, (
        "MARWIL execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    replay_buffer = LocalReplayBuffer(
        learning_starts=config["learning_starts"],
        capacity=config["replay_buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_sequence_length=1,
    )

    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
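

# Added commentary (not in the original module): in the plan above,
# `mode="round_robin"` makes `Concurrently` alternate strictly between
# the store op and the replay op on each iteration, while
# `output_indexes=[1]` surfaces only the replay/training op's output as
# the metrics stream handed to `StandardMetricsReporting`.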


def validate_config(config: TrainerConfigDict) -> None:
    """Checks and updates the config based on settings."""
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

    if config["postprocess_inputs"] is False and config["beta"] > 0.0:
        raise ValueError("`postprocess_inputs` must be True for MARWIL (to "
                         "calculate accum., discounted returns)!")


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    get_policy_class=get_policy_class,
    validate_config=validate_config,
    execution_plan=execution_plan)
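

# A minimal usage sketch (not part of the original module): train MARWIL
# from a previously generated offline dataset. "/tmp/cartpole-out" is a
# hypothetical placeholder path; such data can be produced by running any
# RLlib algorithm with `"output": "/tmp/cartpole-out"` in its config.
if __name__ == "__main__":
    import ray

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["framework"] = "tf"
    # Hypothetical offline dataset; replace with a real path.
    config["input"] = "/tmp/cartpole-out"
    trainer = MARWILTrainer(config=config, env="CartPole-v0")
    for i in range(3):
        result = trainer.train()
        print("iter", i, "reward:", result.get("episode_reward_mean"))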