[RLlib] MARWIL: Move to training_iteration API. (#23798)

Sven Mika 2022-04-11 19:28:32 +02:00 committed by GitHub
parent 5dc958037e
commit a3d4fc74a6
3 changed files with 74 additions and 12 deletions
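In short: instead of composing a `LocalIterator` pipeline in a static `execution_plan()`, MARWIL now overrides `training_iteration()`, so one `Trainer.train()` step becomes plain imperative code: sample, buffer, train, sync weights. A minimal usage sketch against the Ray 1.x API (not part of this commit; the offline-data path is a placeholder):

import ray
from ray.rllib.agents.marwil import DEFAULT_CONFIG, MARWILTrainer

ray.init()

config = DEFAULT_CONFIG.copy()
config["framework"] = "torch"
# Placeholder path: point this at JSON SampleBatch files recorded by a
# behavior policy (e.g. a PG run configured with "output": "/tmp/cartpole-out").
config["input"] = "/tmp/cartpole-out"
# The flag this commit turns on by default for MARWIL: run the new
# `training_iteration()` code path instead of the legacy `execution_plan()`.
config["_disable_execution_plan_api"] = True

trainer = MARWILTrainer(config=config, env="CartPole-v0")
for _ in range(3):
    # Each `train()` call now drives `training_iteration()` under the hood.
    print(trainer.train()["info"])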

@@ -7,11 +7,30 @@ from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentRepl
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
-from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
-from ray.rllib.execution.train_ops import TrainOneStep
+from ray.rllib.execution.rollout_ops import (
+    ConcatBatches,
+    ParallelRollouts,
+    synchronous_parallel_sample,
+)
+from ray.rllib.execution.train_ops import (
+    multi_gpu_train_one_step,
+    TrainOneStep,
+    train_one_step,
+)
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.utils.metrics import (
+    NUM_AGENT_STEPS_SAMPLED,
+    NUM_AGENT_STEPS_TRAINED,
+    NUM_ENV_STEPS_SAMPLED,
+    NUM_ENV_STEPS_TRAINED,
+    WORKER_UPDATE_TIMER,
+)
+from ray.rllib.utils.typing import (
+    PartialTrainerConfigDict,
+    ResultDict,
+    TrainerConfigDict,
+)
 from ray.util.iter import LocalIterator

 # fmt: off
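Among the new imports, `synchronous_parallel_sample` is what replaces the `ParallelRollouts` iterator. Roughly, and only as a simplified sketch of the idea (not RLlib's actual implementation), it does something like this:

import ray
from ray.rllib.policy.sample_batch import SampleBatch

def synchronous_parallel_sample_sketch(worker_set):
    # No remote workers (num_workers=0, MARWIL's default): sample on the
    # local worker directly.
    if not worker_set.remote_workers():
        return worker_set.local_worker().sample()
    # Otherwise, block until every remote worker has returned one rollout
    # (synchronous, hence the name), then concatenate the results.
    sample_batches = ray.get(
        [worker.sample.remote() for worker in worker_set.remote_workers()]
    )
    return SampleBatch.concat_samples(sample_batches)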
@@ -71,6 +90,9 @@ DEFAULT_CONFIG = with_common_config({
     # === Parallelism ===
     "num_workers": 0,
+    # Use new `training_iteration` API (instead of `execution_plan` method).
+    "_disable_execution_plan_api": True,
 })
 # __sphinx_doc_end__
 # fmt: on
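This hunk only flips the default flag; the buffer keys that the new `setup()` (next hunk) reads already existed in MARWIL's config. For orientation, the relevant keys, with illustrative values (check `DEFAULT_CONFIG` for the shipped defaults):

config = {
    # Timesteps to collect before `replay()` starts returning batches.
    "learning_starts": 0,
    # Capacity of the local replay buffer (in timesteps).
    "replay_buffer_size": 10000,
    # Size of the batch pulled from the buffer for each train step.
    "train_batch_size": 2000,
    # Use new `training_iteration` API (instead of `execution_plan` method).
    "_disable_execution_plan_api": True,
}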
@@ -105,6 +127,54 @@ class MARWILTrainer(Trainer):
         else:
             return MARWILTFPolicy

+    @override(Trainer)
+    def setup(self, config: PartialTrainerConfigDict):
+        super().setup(config)
+        # `training_iteration` implementation: Setup buffer in `setup`, not
+        # in `execution_plan` (deprecated).
+        if self.config["_disable_execution_plan_api"] is True:
+            self.local_replay_buffer = MultiAgentReplayBuffer(
+                learning_starts=self.config["learning_starts"],
+                capacity=self.config["replay_buffer_size"],
+                replay_batch_size=self.config["train_batch_size"],
+                replay_sequence_length=1,
+            )
+
+    @override(Trainer)
+    def training_iteration(self) -> ResultDict:
+        # Collect SampleBatches from sample workers.
+        batch = synchronous_parallel_sample(worker_set=self.workers)
+        batch = batch.as_multi_agent()
+        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
+        # Add batch to replay buffer.
+        self.local_replay_buffer.add_batch(batch)
+
+        # Pull batch from replay buffer and train on it.
+        train_batch = self.local_replay_buffer.replay()
+        # Train.
+        if self.config["simple_optimizer"]:
+            train_results = train_one_step(self, train_batch)
+        else:
+            train_results = multi_gpu_train_one_step(self, train_batch)
+        self._counters[NUM_AGENT_STEPS_TRAINED] += batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_TRAINED] += batch.env_steps()
+
+        global_vars = {
+            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
+        }
+
+        # Update weights - after learning on the local worker - on all remote
+        # workers.
+        if self.workers.remote_workers():
+            with self._timers[WORKER_UPDATE_TIMER]:
+                self.workers.sync_weights(global_vars=global_vars)
+
+        # Update global vars on local worker as well.
+        self.workers.local_worker().set_global_vars(global_vars)
+        return train_results
+
     @staticmethod
     @override(Trainer)
     def execution_plan(
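The buffer contract that `training_iteration()` relies on is small: `add_batch()` stores sampled experiences, and `replay()` returns a train batch (or `None` while fewer than `learning_starts` items have been collected). A toy stand-in, illustrative only and far simpler than RLlib's `MultiAgentReplayBuffer`:

import random
from collections import deque

class ToyReplayBuffer:
    """Illustrative stand-in; not RLlib's MultiAgentReplayBuffer."""

    def __init__(self, capacity, learning_starts, replay_batch_size):
        self.storage = deque(maxlen=capacity)
        self.learning_starts = learning_starts
        self.replay_batch_size = replay_batch_size
        self.num_added = 0

    def add_batch(self, batch):
        # A real buffer slices batches into timesteps and keeps them per
        # policy ID; the toy version stores whole batches.
        self.storage.append(batch)
        self.num_added += 1

    def replay(self):
        # Nothing to learn from until `learning_starts` is reached
        # (MARWIL wires this up from its config keys, see above).
        if self.num_added < self.learning_starts:
            return None
        k = min(self.replay_batch_size, len(self.storage))
        return random.sample(list(self.storage), k)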

@@ -14,8 +14,7 @@ DEFAULT_CONFIG = with_common_config({
     # Experimental: By default, switch off preprocessors for PG.
     "_disable_preprocessor_api": True,
-    # PG is the first algo (experimental) to not use the distr. exec API
-    # anymore.
+    # Use new `training_iteration` API (instead of `execution_plan` method).
     "_disable_execution_plan_api": True,
 })
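The second changed file is PG's default config; PG made this switch earlier, and the hunk just rewords its comment to match MARWIL's. Flipping the flag back to `False` still selects the legacy `execution_plan()` path, which can be handy for A/B-debugging the two implementations. A hedged sketch:

import ray
from ray.rllib.agents.pg import PGTrainer

ray.init()

# Setting the flag to False selects PG's legacy `execution_plan()` pipeline;
# leaving it at the (new) default of True uses `training_iteration()`.
trainer = PGTrainer(
    env="CartPole-v0",
    config={"num_workers": 0, "_disable_execution_plan_api": False},
)
print(trainer.train()["episode_reward_mean"])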

@@ -17,7 +17,6 @@ Note that unlike the paper, we currently do not implement straggler mitigation.
 """
 import logging
-import sys
 import time
 from typing import Callable, Optional, Union
@@ -158,12 +157,6 @@ class DDPPOTrainer(PPOTrainer):
         # setting.
         super().validate_config(config)

-        # Error if run on Win.
-        if sys.platform in ["win32", "cygwin"]:
-            raise ValueError(
-                "DD-PPO not supported on Win yet! Due to usage of torch.distributed."
-            )
-
         # Only supported for PyTorch so far.
         if config["framework"] != "torch":
             raise ValueError("Distributed data parallel is only supported for PyTorch")
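The last file is DD-PPO: the Windows guard goes away (along with the now-unused `import sys`), leaving the torch requirement as the sole framework restriction in `validate_config()`. A quick sketch (hypothetical snippet; `validate_config()` runs during trainer setup) of what a non-torch config still triggers:

import ray
from ray.rllib.agents.ppo import DDPPOTrainer

ray.init()

try:
    # `validate_config()` runs during setup and rejects non-torch frameworks.
    DDPPOTrainer(env="CartPole-v0", config={"framework": "tf"})
except ValueError as e:
    print(e)  # "Distributed data parallel is only supported for PyTorch"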